Skip to content

Commit

Permalink
Fixing audio-classification for large PR.
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil committed Sep 1, 2021
1 parent 7e8f4b9 commit cd6e659
Showing 1 changed file with 25 additions and 28 deletions.
53 changes: 25 additions & 28 deletions src/transformers/pipelines/audio_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
from typing import TYPE_CHECKING, Optional, Union
from typing import Union

import numpy as np

from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..file_utils import add_end_docstrings, is_torch_available
from ..utils import logging
from .base import PIPELINE_INIT_ARGS, Pipeline


if TYPE_CHECKING:
from ..modeling_tf_utils import TFPreTrainedModel
from ..modeling_utils import PreTrainedModel

if is_torch_available():
import torch

from ..models.auto.modeling_auto import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING

logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -84,14 +77,10 @@ class AudioClassificationPipeline(Pipeline):
<https://huggingface.co/models?filter=audio-classification>`__.
"""

def __init__(
self,
model: Union["PreTrainedModel", "TFPreTrainedModel"],
feature_extractor: PreTrainedFeatureExtractor,
framework: Optional[str] = None,
**kwargs
):
super().__init__(model, feature_extractor=feature_extractor, framework=framework, **kwargs)
top_k = 5

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

if self.framework != "pt":
raise ValueError(f"The {self.__class__} is only available in PyTorch.")
Expand All @@ -101,7 +90,6 @@ def __init__(
def __call__(
self,
inputs: Union[np.ndarray, bytes, str],
top_k: Optional[int] = None,
**kwargs,
):
"""
Expand All @@ -126,6 +114,16 @@ def __call__(
- **label** (:obj:`str`) -- The label predicted.
- **score** (:obj:`float`) -- The corresponding probability.
"""
return super().__call__(inputs, **kwargs)

def set_parameters(self, top_k=None, **kwargs):
# No parameters on this pipeline right now
if top_k is not None:
self.top_k = top_k
if self.top_k > self.model.config.num_labels:
self.top_k = self.model.config.num_labels

def preprocess(self, inputs):
if isinstance(inputs, str):
with open(inputs, "rb") as f:
inputs = f.read()
Expand All @@ -136,24 +134,23 @@ def __call__(
if not isinstance(inputs, np.ndarray):
raise ValueError("We expect a numpy ndarray as input")
if len(inputs.shape) != 1:
raise ValueError("We expect a single channel audio input for AudioClassificationPipeline")

if top_k is None or top_k > self.model.config.num_labels:
top_k = self.model.config.num_labels
raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")

processed = self.feature_extractor(
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
)
processed = self.ensure_tensor_on_device(**processed)
return processed

with torch.no_grad():
outputs = self.model(**processed)
def forward(self, model_inputs):
model_outputs = self.model(**model_inputs)
return model_outputs

probs = outputs.logits[0].softmax(-1)
scores, ids = probs.topk(top_k)
def postprocess(self, model_outputs):
probs = model_outputs.logits[0].softmax(-1)
scores, ids = probs.topk(self.top_k)

scores = scores.tolist()
ids = ids.tolist()
scores = scores.tolist()
ids = ids.tolist()

labels = [{"score": score, "label": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]

Expand Down

0 comments on commit cd6e659

Please sign in to comment.