Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace 'decord' with 'av' in VideoClassificationPipeline #29747

Merged
merged 8 commits into from
Mar 26, 2024
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1083,6 +1083,7 @@
"add_end_docstrings",
"add_start_docstrings",
"is_apex_available",
"is_av_available",
"is_bitsandbytes_available",
"is_datasets_available",
"is_decord_available",
Expand Down Expand Up @@ -5952,6 +5953,7 @@
add_end_docstrings,
add_start_docstrings,
is_apex_available,
is_av_available,
is_bitsandbytes_available,
is_datasets_available,
is_decord_available,
Expand Down
32 changes: 25 additions & 7 deletions src/transformers/pipelines/video_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@

import requests

from ..utils import add_end_docstrings, is_decord_available, is_torch_available, logging, requires_backends
from ..utils import (
add_end_docstrings,
is_av_available,
is_torch_available,
logging,
requires_backends,
)
from .base import Pipeline, build_pipeline_init_args


if is_decord_available():
if is_av_available():
import av
import numpy as np
from decord import VideoReader


if is_torch_available():
Expand All @@ -33,7 +39,7 @@ class VideoClassificationPipeline(Pipeline):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
requires_backends(self, "decord")
requires_backends(self, "av")
self.check_model_type(MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES)

def _sanitize_parameters(self, top_k=None, num_frames=None, frame_sampling_rate=None):
Expand Down Expand Up @@ -90,14 +96,13 @@ def preprocess(self, video, num_frames=None, frame_sampling_rate=1):
if video.startswith("http://") or video.startswith("https://"):
video = BytesIO(requests.get(video).content)

videoreader = VideoReader(video)
videoreader.seek(0)
container = av.open(video)

start_idx = 0
end_idx = num_frames * frame_sampling_rate - 1
indices = np.linspace(start_idx, end_idx, num=num_frames, dtype=np.int64)

video = videoreader.get_batch(indices).asnumpy()
video = read_video_pyav(container, indices)
video = list(video)

model_inputs = self.image_processor(video, return_tensors=self.framework)
Expand All @@ -120,3 +125,16 @@ def postprocess(self, model_outputs, top_k=5):
scores = scores.tolist()
ids = ids.tolist()
return [{"score": score, "label": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]


def read_video_pyav(container, indices):
    """
    Decode the frames at the requested positions from an opened video container.

    Args:
        container: Object exposing `seek(offset)` and `decode(video=0)` (e.g. the result of `av.open`). It is
            rewound with `seek(0)` before decoding.
        indices: Sequence of frame positions to keep. Its first and last entries bound the range that is scanned;
            decoding stops as soon as a frame position exceeds the last entry.

    Returns:
        `np.ndarray`: The selected frames, each converted with `to_ndarray(format="rgb24")` and stacked along a
        new leading axis, in decoding order.

    Raises:
        ValueError: If no frame matching `indices` could be decoded.
    """
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    # Build the lookup once: `i in <ndarray>` would rescan the whole array for every decoded frame.
    wanted = set(np.asarray(indices).ravel().tolist())

    frames = []
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in wanted:
            frames.append(frame.to_ndarray(format="rgb24"))

    if not frames:
        # np.stack([]) would fail with an unhelpful "need at least one array to stack".
        raise ValueError(
            f"No frames could be decoded for the requested indices (first={start_index}, last={end_index})."
        )
    return np.stack(frames)
8 changes: 8 additions & 0 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
is_aqlm_available,
is_auto_awq_available,
is_auto_gptq_available,
is_av_available,
is_bitsandbytes_available,
is_bs4_available,
is_cv2_available,
Expand Down Expand Up @@ -1010,6 +1011,13 @@ def require_aqlm(test_case):
return unittest.skipUnless(is_aqlm_available(), "test requires aqlm")(test_case)


def require_av(test_case):
    """
    Decorator for tests that depend on av: the test is skipped with the reason "test requires av" when
    `is_av_available()` returns a falsy value.
    """
    skip_unless_av = unittest.skipUnless(is_av_available(), "test requires av")
    return skip_unless_av(test_case)


def require_bitsandbytes(test_case):
"""
Decorator marking a test that requires the bitsandbytes library. Will be skipped when the library or its hard dependency torch is not installed.
Expand Down
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@
is_aqlm_available,
is_auto_awq_available,
is_auto_gptq_available,
is_av_available,
is_bitsandbytes_available,
is_bs4_available,
is_coloredlogs_available,
Expand Down
16 changes: 16 additions & 0 deletions src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
_apex_available = _is_package_available("apex")
_aqlm_available = _is_package_available("aqlm")
_av_available = importlib.util.find_spec("av") is not None
_bitsandbytes_available = _is_package_available("bitsandbytes")
_galore_torch_available = _is_package_available("galore_torch")
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
Expand Down Expand Up @@ -656,6 +657,10 @@ def is_aqlm_available():
return _aqlm_available


def is_av_available() -> bool:
    """Return the module-level `_av_available` flag (whether `av` was found when this module was loaded)."""
    return _av_available


def is_ninja_available():
r"""
Code comes from *torch.utils.cpp_extension.is_ninja_available()*. Returns `True` if the
Expand Down Expand Up @@ -1012,6 +1017,16 @@ def is_mlx_available():
return _mlx_available


# docstyle-ignore
AV_IMPORT_ERROR = """
{0} requires the PyAv library but it was not found in your environment. You can install it with:
```
pip install av
```
Please note that you may need to restart your runtime after installation.
"""


# docstyle-ignore
CV2_IMPORT_ERROR = """
{0} requires the OpenCV library but it was not found in your environment. You can install it with:
Expand Down Expand Up @@ -1336,6 +1351,7 @@ def is_mlx_available():

BACKENDS_MAPPING = OrderedDict(
[
("av", (is_av_available, AV_IMPORT_ERROR)),
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
("cv2", (is_cv2_available, CV2_IMPORT_ERROR)),
("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)),
Expand Down
4 changes: 2 additions & 2 deletions tests/pipelines/test_pipelines_video_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from transformers.testing_utils import (
is_pipeline_test,
nested_simplify,
require_decord,
require_av,
require_tf,
require_torch,
require_torch_or_tf,
Expand All @@ -34,7 +34,7 @@
@is_pipeline_test
@require_torch_or_tf
@require_vision
@require_decord
@require_av
class VideoClassificationPipelineTests(unittest.TestCase):
model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING

Expand Down