Skip to content

Commit

Permalink
Replace 'decord' with 'av' in VideoClassificationPipeline (#29747)
Browse files Browse the repository at this point in the history
* replace the 'decord' with 'av' in VideoClassificationPipeline

* fix the check of backend in VideoClassificationPipeline

* adjust the order of imports

* format 'video_classification.py'

* format 'video_classification.py' with ruff

---------

Co-authored-by: wanqiancheng <13541261013@163.com>
  • Loading branch information
2 people authored and Ita Zaporozhets committed May 14, 2024
1 parent e0bc2f7 commit 9890fb1
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 9 deletions.
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1083,6 +1083,7 @@
"add_end_docstrings",
"add_start_docstrings",
"is_apex_available",
"is_av_available",
"is_bitsandbytes_available",
"is_datasets_available",
"is_decord_available",
Expand Down Expand Up @@ -5951,6 +5952,7 @@
add_end_docstrings,
add_start_docstrings,
is_apex_available,
is_av_available,
is_bitsandbytes_available,
is_datasets_available,
is_decord_available,
Expand Down
32 changes: 25 additions & 7 deletions src/transformers/pipelines/video_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@

import requests

from ..utils import add_end_docstrings, is_decord_available, is_torch_available, logging, requires_backends
from ..utils import (
add_end_docstrings,
is_av_available,
is_torch_available,
logging,
requires_backends,
)
from .base import Pipeline, build_pipeline_init_args


if is_decord_available():
if is_av_available():
import av
import numpy as np
from decord import VideoReader


if is_torch_available():
Expand All @@ -33,7 +39,7 @@ class VideoClassificationPipeline(Pipeline):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
requires_backends(self, "decord")
requires_backends(self, "av")
self.check_model_type(MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES)

def _sanitize_parameters(self, top_k=None, num_frames=None, frame_sampling_rate=None):
Expand Down Expand Up @@ -90,14 +96,13 @@ def preprocess(self, video, num_frames=None, frame_sampling_rate=1):
if video.startswith("http://") or video.startswith("https://"):
video = BytesIO(requests.get(video).content)

videoreader = VideoReader(video)
videoreader.seek(0)
container = av.open(video)

start_idx = 0
end_idx = num_frames * frame_sampling_rate - 1
indices = np.linspace(start_idx, end_idx, num=num_frames, dtype=np.int64)

video = videoreader.get_batch(indices).asnumpy()
video = read_video_pyav(container, indices)
video = list(video)

model_inputs = self.image_processor(video, return_tensors=self.framework)
Expand All @@ -120,3 +125,16 @@ def postprocess(self, model_outputs, top_k=5):
scores = scores.tolist()
ids = ids.tolist()
return [{"score": score, "label": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]


def read_video_pyav(container, indices):
frames = []
container.seek(0)
start_index = indices[0]
end_index = indices[-1]
for i, frame in enumerate(container.decode(video=0)):
if i > end_index:
break
if i >= start_index and i in indices:
frames.append(frame)
return np.stack([x.to_ndarray(format="rgb24") for x in frames])
8 changes: 8 additions & 0 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
is_aqlm_available,
is_auto_awq_available,
is_auto_gptq_available,
is_av_available,
is_bitsandbytes_available,
is_bs4_available,
is_cv2_available,
Expand Down Expand Up @@ -1010,6 +1011,13 @@ def require_aqlm(test_case):
return unittest.skipUnless(is_aqlm_available(), "test requires aqlm")(test_case)


def require_av(test_case):
"""
Decorator marking a test that requires av
"""
return unittest.skipUnless(is_av_available(), "test requires av")(test_case)


def require_bitsandbytes(test_case):
"""
Decorator marking a test that requires the bitsandbytes library. Will be skipped when the library or its hard dependency torch is not installed.
Expand Down
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@
is_aqlm_available,
is_auto_awq_available,
is_auto_gptq_available,
is_av_available,
is_bitsandbytes_available,
is_bs4_available,
is_coloredlogs_available,
Expand Down
16 changes: 16 additions & 0 deletions src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
_apex_available = _is_package_available("apex")
_aqlm_available = _is_package_available("aqlm")
_av_available = importlib.util.find_spec("av") is not None
_bitsandbytes_available = _is_package_available("bitsandbytes")
_galore_torch_available = _is_package_available("galore_torch")
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
Expand Down Expand Up @@ -656,6 +657,10 @@ def is_aqlm_available():
return _aqlm_available


def is_av_available():
return _av_available


def is_ninja_available():
r"""
Code comes from *torch.utils.cpp_extension.is_ninja_available()*. Returns `True` if the
Expand Down Expand Up @@ -1012,6 +1017,16 @@ def is_mlx_available():
return _mlx_available


# docstyle-ignore
AV_IMPORT_ERROR = """
{0} requires the PyAv library but it was not found in your environment. You can install it with:
```
pip install av
```
Please note that you may need to restart your runtime after installation.
"""


# docstyle-ignore
CV2_IMPORT_ERROR = """
{0} requires the OpenCV library but it was not found in your environment. You can install it with:
Expand Down Expand Up @@ -1336,6 +1351,7 @@ def is_mlx_available():

BACKENDS_MAPPING = OrderedDict(
[
("av", (is_av_available, AV_IMPORT_ERROR)),
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
("cv2", (is_cv2_available, CV2_IMPORT_ERROR)),
("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)),
Expand Down
4 changes: 2 additions & 2 deletions tests/pipelines/test_pipelines_video_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from transformers.testing_utils import (
is_pipeline_test,
nested_simplify,
require_decord,
require_av,
require_tf,
require_torch,
require_torch_or_tf,
Expand All @@ -34,7 +34,7 @@
@is_pipeline_test
@require_torch_or_tf
@require_vision
@require_decord
@require_av
class VideoClassificationPipelineTests(unittest.TestCase):
model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING

Expand Down

0 comments on commit 9890fb1

Please sign in to comment.