Skip to content

Commit

Permalink
ENH: Improve handling of optional dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
ghisvail committed May 17, 2024
1 parent 44822f1 commit 1869a9f
Show file tree
Hide file tree
Showing 43 changed files with 294 additions and 216 deletions.
6 changes: 0 additions & 6 deletions medkit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +0,0 @@
__all__ = ["audio", "core", "io", "text", "tools"]

from medkit.core.utils import modules_are_available

if modules_are_available(["torch"]):
__all__ += ["training"]
52 changes: 52 additions & 0 deletions medkit/_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from __future__ import annotations

import importlib
import inspect
import sys
from typing import TYPE_CHECKING

if TYPE_CHECKING:
import types


__all__ = ["import_optional"]


def import_optional(name: str, extra: str | None = None) -> types.ModuleType:
"""Import an optional dependency or raise an appropriate error message.
Parameters
----------
name : str
Module name to import.
extra : str, optional
Group of optional dependencies to suggest installing if the import fails.
If unspecified, assume the extra is named after the caller's module.
Returns
-------
ModuleType
The successfully imported module.
Raises
------
ModuleNotFoundError
In case the requested import failed.
"""
try:
module = importlib.import_module(name)
except ModuleNotFoundError as err:
if not extra:
calling_module = inspect.getmodulename(inspect.stack()[1][1])
extra = calling_module.replace("_", "-") if calling_module else None

note = f"Consider installing the appropriate extra with:\npip install 'medkit-lib[{extra}]'" if extra else None

if sys.version_info >= (3, 11):
if note:
err.add_note(note)
raise

message = "\n".join([str(err), note or ""])
raise ModuleNotFoundError(message) from err
return module
9 changes: 0 additions & 9 deletions medkit/audio/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +0,0 @@
__all__ = []

from medkit.core.utils import modules_are_available

if modules_are_available(["pyannote"]) and modules_are_available(["pyannote.core", "pyannote.metrics"]):
__all__ += ["diarization"]

if modules_are_available(["speechbrain"]):
__all__ += ["transcription"]
13 changes: 9 additions & 4 deletions medkit/audio/metrics/diarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,18 @@
# we import pandas manually first.
# So as a workaround, we always import pandas before importing something from pyannote
import pandas as pd # noqa: F401
from pyannote.core.annotation import Annotation as PAAnnotation
from pyannote.core.annotation import Segment as PASegment
from pyannote.core.annotation import Timeline as PATimeline
from pyannote.metrics.diarization import GreedyDiarizationErrorRate

from medkit._import import import_optional
from medkit.core.audio import AudioDocument, Segment

_ = import_optional("pyannote.core", extra="metrics-diarization")
_ = import_optional("pyannote.metrics", extra="metrics-diarization")

from pyannote.core.annotation import Annotation as PAAnnotation # noqa: E402
from pyannote.core.annotation import Segment as PASegment # noqa: E402
from pyannote.core.annotation import Timeline as PATimeline # noqa: E402
from pyannote.metrics.diarization import GreedyDiarizationErrorRate # noqa: E402

logger = logging.getLogger(__name__)


Expand Down
7 changes: 5 additions & 2 deletions medkit/audio/metrics/transcription.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
import string
from typing import TYPE_CHECKING, Sequence

from speechbrain.utils.metric_stats import ErrorRateStats

from medkit._import import import_optional
from medkit.text.utils.decoding import get_ascii_from_unicode

_ = import_optional("speechbrain", extra="metrics-transcription")

if TYPE_CHECKING:
from medkit.core.audio import AudioDocument, Segment

Expand Down Expand Up @@ -136,6 +137,8 @@ def compute(
TranscriptionEvaluatorResult
Computed metrics
"""
from speechbrain.utils.metric_stats import ErrorRateStats

if len(reference) != len(predicted):
msg = "Reference and predicted must have the same length"
raise ValueError(msg)
Expand Down
9 changes: 6 additions & 3 deletions medkit/audio/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

from medkit.audio.preprocessing.downmixer import Downmixer
from medkit.audio.preprocessing.power_normalizer import PowerNormalizer
from medkit.core.utils import modules_are_available

if modules_are_available(["resampy"]):
__all__ += ["resampler"]
try:
from medkit.audio.preprocessing.resampler import Resampler

__all__ += ["Resampler"]
except ModuleNotFoundError:
pass
4 changes: 3 additions & 1 deletion medkit/audio/preprocessing/resampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

__all__ = ["Resampler"]

import resampy

from medkit._import import import_optional
from medkit.core.audio import MemoryAudioBuffer, PreprocessingOperation, Segment

resampy = import_optional("resampy")


class Resampler(PreprocessingOperation):
"""Resampling operation relying on the resampy package.
Expand Down
16 changes: 11 additions & 5 deletions medkit/audio/segmentation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
__all__ = []

from medkit.core.utils import modules_are_available
try:
from medkit.audio.segmentation.pa_speaker_detector import PASpeakerDetector

if modules_are_available(["webrtcvad"]):
__all__ += ["webrtc_voice_detector"]
__all__ += ["PASpeakerDetector"]
except ModuleNotFoundError:
pass

if modules_are_available(["pyannote"]) and modules_are_available(["torch", "pyannote.audio"]):
__all__ += ["pa_speaker_detector"]
try:
from medkit.audio.segmentation.webrtc_voice_detector import WebRTCVoiceDetector

__all__ += ["WebRTCVoiceDetector"]
except ModuleNotFoundError:
pass
12 changes: 7 additions & 5 deletions medkit/audio/segmentation/pa_speaker_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
# we import pandas manually first.
# So as a workaround, we always import pandas before importing something from pyannote
import pandas as pd # noqa: F401
import torch
from pyannote.audio import Pipeline
from pyannote.audio.pipelines import SpeakerDiarization

from medkit._import import import_optional
from medkit.core import Attribute
from medkit.core.audio import Segment, SegmentationOperation, Span

audio = import_optional("pyannote.audio")
torch = import_optional("torch")


if TYPE_CHECKING:
from pathlib import Path

Expand Down Expand Up @@ -102,11 +104,11 @@ def __init__(
self.min_duration = min_duration

torch_device = torch.device("cpu" if device < 0 else f"cuda:{device}")
self._pipeline = Pipeline.from_pretrained(model, use_auth_token=hf_auth_token)
self._pipeline = audio.Pipeline.from_pretrained(model, use_auth_token=hf_auth_token)
if self._pipeline is None:
msg = f"Could not instantiate pretrained pipeline with '{model}'"
raise ValueError(msg)
if not isinstance(self._pipeline, SpeakerDiarization):
if not isinstance(self._pipeline, audio.pipelines.SpeakerDiarization):
msg = (
f"'{model}' does not correspond to a SpeakerDiarization pipeline. Got"
f" object of type {type(self._pipeline)}"
Expand Down
4 changes: 3 additions & 1 deletion medkit/audio/segmentation/webrtc_voice_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
from typing import Iterator

import numpy as np
import webrtcvad
from typing_extensions import Literal

from medkit._import import import_optional
from medkit.core.audio import Segment, SegmentationOperation, Span

webrtcvad = import_optional("webrtcvad")

_SUPPORTED_SAMPLE_RATES = {8000, 16000, 32000, 48000}


Expand Down
9 changes: 6 additions & 3 deletions medkit/audio/transcription/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
from medkit.audio.transcription.doc_transcriber import DocTranscriber, TranscriptionOperation
from medkit.audio.transcription.transcribed_text_document import TranscribedTextDocument
from medkit.core.utils import modules_are_available

__all__ = [
"DocTranscriber",
"TranscriptionOperation",
"TranscribedTextDocument",
]

if modules_are_available(["torchaudio", "transformers"]):
try:
from medkit.audio.transcription.hf_transcriber import HFTranscriber

__all__ += ["HFTranscriber"]
except ModuleNotFoundError:
pass

if modules_are_available(["torch", "speechbrain"]):
try:
from medkit.audio.transcription.sb_transcriber import SBTranscriber

__all__ += ["SBTranscriber"]
except ModuleNotFoundError:
pass
8 changes: 4 additions & 4 deletions medkit/audio/transcription/hf_transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

from typing import TYPE_CHECKING

import transformers
from transformers import AutomaticSpeechRecognitionPipeline

from medkit._import import import_optional
from medkit.core import Attribute, Operation

transformers = import_optional("transformers")

if TYPE_CHECKING:
from pathlib import Path

Expand Down Expand Up @@ -99,7 +99,7 @@ def __init__(
task=task,
model=self.model_name,
feature_extractor=self.model_name,
pipeline_class=AutomaticSpeechRecognitionPipeline,
pipeline_class=transformers.AutomaticSpeechRecognitionPipeline,
device=self.device,
batch_size=batch_size,
token=hf_auth_token,
Expand Down
11 changes: 7 additions & 4 deletions medkit/audio/transcription/sb_transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
from pathlib import Path
from typing import TYPE_CHECKING

import speechbrain as sb

from medkit._compat import batched
from medkit._import import import_optional
from medkit.core import Attribute, Operation

speechbrain = import_optional("speechbrain")

if TYPE_CHECKING:
from medkit.core.audio import AudioBuffer, Segment

Expand Down Expand Up @@ -88,7 +89,7 @@ def __init__(
self.batch_size = batch_size
self._torch_device = "cpu" if self.device < 0 else f"cuda:{self.device}"

asr_class = sb.pretrained.EncoderDecoderASR if needs_decoder else sb.pretrained.EncoderASR
asr_class = speechbrain.pretrained.EncoderDecoderASR if needs_decoder else speechbrain.pretrained.EncoderASR

self._asr = asr_class.from_hparams(source=model, savedir=cache_dir, run_opts={"device": self._torch_device})

Expand Down Expand Up @@ -129,7 +130,9 @@ def _transcribe_audios(self, audios: list[AudioBuffer]) -> list[str]:

# group audios in batch of same length with padding
for batched_audios in batched(audios, self.batch_size):
padded_batch = sb.dataio.batch.PaddedBatch([{"wav": a.read().reshape((-1,))} for a in batched_audios])
padded_batch = speechbrain.dataio.batch.PaddedBatch(
[{"wav": a.read().reshape((-1,))} for a in batched_audios]
)
padded_batch.to(self._torch_device)

batch_texts, _ = self._asr.transcribe_batch(padded_batch.wav.data, padded_batch.wav.lengths)
Expand Down
11 changes: 0 additions & 11 deletions medkit/core/utils.py

This file was deleted.

9 changes: 0 additions & 9 deletions medkit/text/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +0,0 @@
__all__ = []

from medkit.core.utils import modules_are_available

if modules_are_available(["seqeval", "transformers", "torch"]):
__all__ += ["ner"]

if modules_are_available(["sklearn"]):
__all__ += ["classification"]
8 changes: 5 additions & 3 deletions medkit/text/metrics/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import logging
from typing import TYPE_CHECKING

from sklearn.metrics import classification_report, cohen_kappa_score
from typing_extensions import Literal

from medkit._import import import_optional
from medkit.text.metrics.irr_utils import krippendorff_alpha

metrics = import_optional("sklearn.metrics", extra="metrics-text-classification")

if TYPE_CHECKING:
from medkit.core.text import TextDocument

Expand Down Expand Up @@ -100,7 +102,7 @@ def compute_classification_report(
true_tags = self._extract_attr_values(true_docs)
pred_tags = self._extract_attr_values(predicted_docs)

report = classification_report(
report = metrics.classification_report(
y_true=true_tags,
y_pred=pred_tags,
output_dict=True,
Expand Down Expand Up @@ -150,7 +152,7 @@ def compute_cohen_kappa(
ann2_tags = self._extract_attr_values(docs_annotator_2)

return {
"cohen_kappa": cohen_kappa_score(y1=ann1_tags, y2=ann2_tags),
"cohen_kappa": metrics.cohen_kappa_score(y1=ann1_tags, y2=ann2_tags),
"support": len(ann1_tags),
}

Expand Down
12 changes: 7 additions & 5 deletions medkit/text/metrics/ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@

from typing import TYPE_CHECKING, Any

from seqeval.metrics import accuracy_score, classification_report
from seqeval.scheme import BILOU, IOB2
from typing_extensions import Literal

from medkit._import import import_optional
from medkit.core.text import Entity, TextDocument, span_utils
from medkit.text.ner import hf_tokenization_utils

metrics = import_optional("seqeval.metrics", extra="metrics-ner")
scheme_ = import_optional("seqeval.scheme", extra="metrics-ner")

if TYPE_CHECKING:
from medkit.training.utils import BatchData

Expand All @@ -25,11 +27,11 @@ def _compute_seqeval_from_dict(
"""Compute seqeval metrics using preprocessed data."""
# internal configuration for seqeval
# 'bilou' only works with 'strict' mode
scheme = BILOU if tagging_scheme == "bilou" else IOB2
scheme = scheme_.BILOU if tagging_scheme == "bilou" else scheme_.IOB2
mode = "strict" if tagging_scheme == "bilou" else None

# returns precision, recall, F1 score for each class.
report = classification_report(
report = metrics.classification_report(
y_true=y_true_all,
y_pred=y_pred_all,
scheme=scheme,
Expand All @@ -40,7 +42,7 @@ def _compute_seqeval_from_dict(
# add average metrics
scores = {f"{average}_{key}": value for key, value in report[f"{average} avg"].items()}
scores["support"] = scores.pop(f"{average}_support")
scores["accuracy"] = accuracy_score(y_true=y_true_all, y_pred=y_pred_all)
scores["accuracy"] = metrics.accuracy_score(y_true=y_true_all, y_pred=y_pred_all)

if return_metrics_by_label:
for value_key in report:
Expand Down
Loading

0 comments on commit 1869a9f

Please sign in to comment.