diff --git a/src/transformers/pipelines/audio_classification.py b/src/transformers/pipelines/audio_classification.py
index b3a3c7f948a089..ee6b09e06b36e6 100644
--- a/src/transformers/pipelines/audio_classification.py
+++ b/src/transformers/pipelines/audio_classification.py
@@ -12,23 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import subprocess
-from typing import TYPE_CHECKING, Optional, Union
+from typing import Union

 import numpy as np

-from ..feature_extraction_utils import PreTrainedFeatureExtractor
 from ..file_utils import add_end_docstrings, is_torch_available
 from ..utils import logging
 from .base import PIPELINE_INIT_ARGS, Pipeline


-if TYPE_CHECKING:
-    from ..modeling_tf_utils import TFPreTrainedModel
-    from ..modeling_utils import PreTrainedModel
-
 if is_torch_available():
-    import torch
-
     from ..models.auto.modeling_auto import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING

 logger = logging.get_logger(__name__)
@@ -84,14 +77,10 @@ class AudioClassificationPipeline(Pipeline):
     <https://huggingface.co/models?filter=audio-classification>`__.
     """

-    def __init__(
-        self,
-        model: Union["PreTrainedModel", "TFPreTrainedModel"],
-        feature_extractor: PreTrainedFeatureExtractor,
-        framework: Optional[str] = None,
-        **kwargs
-    ):
-        super().__init__(model, feature_extractor=feature_extractor, framework=framework, **kwargs)
+    top_k = 5
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)

         if self.framework != "pt":
             raise ValueError(f"The {self.__class__} is only available in PyTorch.")
@@ -101,7 +90,6 @@ def __init__(
     def __call__(
         self,
         inputs: Union[np.ndarray, bytes, str],
-        top_k: Optional[int] = None,
         **kwargs,
     ):
         """
@@ -126,6 +114,16 @@ def __call__(
             - **label** (:obj:`str`) -- The label predicted.
             - **score** (:obj:`float`) -- The corresponding probability.
         """
+        return super().__call__(inputs, **kwargs)
+
+    def set_parameters(self, top_k=None, **kwargs):
+        # No parameters on this pipeline right now
+        if top_k is not None:
+            self.top_k = top_k
+        if self.top_k > self.model.config.num_labels:
+            self.top_k = self.model.config.num_labels
+
+    def preprocess(self, inputs):
         if isinstance(inputs, str):
             with open(inputs, "rb") as f:
                 inputs = f.read()
@@ -136,24 +134,23 @@ def __call__(
         if not isinstance(inputs, np.ndarray):
             raise ValueError("We expect a numpy ndarray as input")
         if len(inputs.shape) != 1:
-            raise ValueError("We expect a single channel audio input for AudioClassificationPipeline")
-
-        if top_k is None or top_k > self.model.config.num_labels:
-            top_k = self.model.config.num_labels
+            raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")

         processed = self.feature_extractor(
             inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
         )
-        processed = self.ensure_tensor_on_device(**processed)
+        return processed

-        with torch.no_grad():
-            outputs = self.model(**processed)
+    def forward(self, model_inputs):
+        model_outputs = self.model(**model_inputs)
+        return model_outputs

-        probs = outputs.logits[0].softmax(-1)
-        scores, ids = probs.topk(top_k)
+    def postprocess(self, model_outputs):
+        probs = model_outputs.logits[0].softmax(-1)
+        scores, ids = probs.topk(self.top_k)

-        scores = scores.tolist()
-        ids = ids.tolist()
+        scores = scores.tolist()
+        ids = ids.tolist()

         labels = [{"score": score, "label": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]
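
After this change, top_k is no longer an explicit argument of __call__: it travels through **kwargs, is stored on the pipeline by set_parameters(), and is clamped to model.config.num_labels before postprocess() calls probs.topk(self.top_k). A minimal usage sketch follows, assuming the base Pipeline.__call__ forwards extra keyword arguments to set_parameters() as the new override suggests; the checkpoint name and file path are only illustrative, not part of the patch.

from transformers import pipeline

# Illustrative checkpoint and audio file, not taken from the patch.
classifier = pipeline("audio-classification", model="superb/wav2vec2-base-superb-ks")

# top_k is assumed to be consumed by set_parameters(), which clamps it to
# model.config.num_labels before postprocess() runs probs.topk(self.top_k).
predictions = classifier("sample.wav", top_k=3)
# predictions: list of {"score": float, "label": str} dicts, best score first.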