Description
🐛 Describe the bug
To Reproduce
Running on an NVIDIA A10G GPU.
1. Install the nightly builds of torch, torchvision, torchaudio, and torchcodec (CUDA 12.6):
python3 -m pip uninstall -y torch torchvision torchaudio torchcodec && python3 -m pip install --pre torch torchvision torchaudio torchcodec --extra-index-url https://download.pytorch.org/whl/nightly/cu126
Installed versions:
torch==2.9.0.dev20250906+cu126
torchaudio==2.8.0.dev20250906+cu126
torchcodec==0.7.0.dev20250906+cu126
torchvision==0.24.0.dev20250906+cu126
2. Install transformers from source at a fixed commit:
git clone https://github.com/huggingface/transformers.git && cd transformers && git fetch origin && git checkout bb45d3631ec7026db04a77d33a52b31766372160 && pip install -e .[torch,testing]
3. Run the test:
RUN_SLOW=1 python3 -m pytest -v tests/models/internvl/test_modeling_internvl.py::InternVLQwen2IntegrationTest::test_qwen2_medium_model_integration_video
This fails with:
RuntimeError: torchcodec_ns::get_frames_at_indices() Expected a value of type 'List[int]' for argument 'frame_indices' but instead found type 'Tensor'.
Full log
=============================================================================================== FAILURES ===============================================================================================
________________________________________________________________ InternVLQwen2IntegrationTest.test_qwen2_medium_model_integration_video ________________________________________________________________
self = <tests.models.internvl.test_modeling_internvl.InternVLQwen2IntegrationTest testMethod=test_qwen2_medium_model_integration_video>
    @require_av
    @require_bitsandbytes
    def test_qwen2_medium_model_integration_video(self):
        processor = AutoProcessor.from_pretrained(self.medium_model_checkpoint)
        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
        model = InternVLForConditionalGeneration.from_pretrained(
            self.medium_model_checkpoint, quantization_config=quantization_config
        )
        # Prepare inputs
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "video",
                        "url": "https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4",
                    },
                    {"type": "text", "text": "What type of shot is the man performing?"},
                ],
            }
        ]
>       inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            num_frames=8,
        ).to(torch_device, dtype=torch.float16)
tests/models/internvl/test_modeling_internvl.py:483:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
src/transformers/utils/deprecation.py:172: in wrapped_func
    return func(*args, **kwargs)
src/transformers/utils/deprecation.py:172: in wrapped_func
    return func(*args, **kwargs)
src/transformers/processing_utils.py:1634: in apply_chat_template
    out = self(
src/transformers/models/internvl/processing_internvl.py:226: in __call__
    video_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
src/transformers/video_processing_utils.py:212: in __call__
    return self.preprocess(videos, **kwargs)
src/transformers/video_processing_utils.py:378: in preprocess
    videos, video_metadata = self._decode_and_sample_videos(
src/transformers/video_processing_utils.py:328: in _decode_and_sample_videos
    videos, video_metadata = self.fetch_videos(videos, sample_indices_fn=sample_indices_fn)
src/transformers/video_processing_utils.py:889: in fetch_videos
    return list(zip(*[self.fetch_videos(x, sample_indices_fn=sample_indices_fn) for x in video_url_or_urls]))
src/transformers/video_processing_utils.py:889: in <listcomp>
    return list(zip(*[self.fetch_videos(x, sample_indices_fn=sample_indices_fn) for x in video_url_or_urls]))
src/transformers/video_processing_utils.py:891: in fetch_videos
    return load_video(video_url_or_urls, backend=backend, sample_indices_fn=sample_indices_fn)
src/transformers/video_utils.py:711: in load_video
    video, metadata = video_decoder(file_obj, sample_indices_fn, **kwargs)
src/transformers/video_utils.py:602: in read_video_torchcodec
    video = decoder.get_frames_at(indices=indices).data.contiguous()
/usr/local/lib/python3.10/dist-packages/torchcodec/decoders/_video_decoder.py:250: in get_frames_at
    data, pts_seconds, duration_seconds = core.get_frames_at_indices(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <OpOverload(op='torchcodec_ns.get_frames_at_indices', overload='default')>
args = (tensor([ 94210107637760, 94190266082880, 94190266083584,
94190266083584, 4294967297, 4618565266858196337,
288230376154661121, 4710891217590681600]),)
kwargs = {'frame_indices': tensor([ 11, 34, 56, 79, 102, 125, 147, 170], dtype=torch.int32)}
    def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T:
>       return self._op(*args, **kwargs)
E RuntimeError: torchcodec_ns::get_frames_at_indices() Expected a value of type 'List[int]' for argument 'frame_indices' but instead found type 'Tensor'.
E Position: 1
E Value: tensor([ 11, 34, 56, 79, 102, 125, 147, 170], dtype=torch.int32)
E Declaration: torchcodec_ns::get_frames_at_indices(Tensor(a!) decoder, *, int[] frame_indices) -> (Tensor, Tensor, Tensor)
E Cast error details: Unable to cast Python instance of type <class 'torch.Tensor'> to C++ type '?' (#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)
/usr/local/lib/python3.10/dist-packages/torch/_ops.py:841: RuntimeError
4. Versions
root@d6713c2f5f1d:/transformers# python3 collect_env.py
/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
Collecting environment information...
PyTorch version: 2.9.0.dev20250905+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35
Python version: 3.10.12 (main, Aug 15 2025, 14:32:43) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.10.240-238.959.amzn2.x86_64-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to:
GPU models and configuration: GPU 0: NVIDIA A10G
Nvidia driver version: 550.163.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.3.0
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 16
On-line CPU(s) list: 0-15
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7R32
CPU family: 23
Model: 49
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
Stepping: 0
BogoMIPS: 5600.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch topoext ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru wbnoinvd arat npt nrip_save rdpid
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 256 KiB (8 instances)
L1i cache: 256 KiB (8 instances)
L2 cache: 4 MiB (8 instances)
L3 cache: 32 MiB (2 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-15
Vulnerability Gather data sampling: Not affected
Vulnerability Indirect target selection: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsa: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] onnx==1.19.0
[pip3] onnxconverter-common==1.16.0
[pip3] onnxruntime==1.22.1
[pip3] onnxruntime-tools==1.7.0
[pip3] pytorch-triton==3.4.0+gitf7888497
[pip3] tf2onnx==1.8.4
[pip3] torch==2.9.0.dev20250905+cu126
[pip3] torchaudio==2.8.0.dev20250905+cu126
[pip3] torchcodec==0.7.0.dev20250905+cu126
[pip3] torchvision==0.24.0.dev20250905+cu126
[pip3] triton==3.4.0
[conda] Could not collect
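Step 1 lists `.dev20250906` builds, while the collect_env output above reports `.dev20250905` builds, so it may be worth confirming which nightlies are actually installed and that they share one build date. The sketch below uses only the standard library (`importlib.metadata`); `nightly_date`, `installed_versions` and `nightlies_match` are hypothetical helpers, and the `.devYYYYMMDD` pattern is assumed from the version strings shown in this report.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Packages installed from the nightly index in step 1.
PACKAGES = ("torch", "torchvision", "torchaudio", "torchcodec")
# Assumed nightly format, taken from this report: "2.9.0.dev20250906+cu126".
_DEV_DATE = re.compile(r"\.dev(\d{8})")


def nightly_date(ver):
    """Return the YYYYMMDD build date of a nightly version string, or None."""
    match = _DEV_DATE.search(ver)
    return match.group(1) if match else None


def installed_versions(packages=PACKAGES):
    """Map each package name to its installed version, or None if missing."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


def nightlies_match(versions):
    """True if all nightly versions share one build date (others are ignored)."""
    dates = {nightly_date(v) for v in versions.values() if v is not None}
    dates.discard(None)
    return len(dates) <= 1


if __name__ == "__main__":
    found = installed_versions()
    for name, ver in found.items():
        print(f"{name}=={ver}" if ver else f"{name}: not installed")
    print("nightly build dates match:", nightlies_match(found))
```

Running it in the same environment as `collect_env.py` should show whether the 20250905 or the 20250906 builds are the ones actually in use.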
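The failure in step 3 does not appear to depend on transformers itself: the log shows the tensor passed to `VideoDecoder.get_frames_at` reaching `core.get_frames_at_indices` unchanged, and that op is declared with `int[] frame_indices`. The sketch below should therefore trigger the same error with torchcodec alone. It assumes the `tennis.mp4` fixture from the test has been downloaded locally and its path is passed as the first command-line argument; it reuses the indices and `int32` dtype from the failing call, and the function name `reproduce` is illustrative.

```python
import sys


def reproduce(video_path):
    """Call torchcodec directly with the same frame indices transformers used."""
    # Imported lazily so this file can be loaded without torch/torchcodec installed.
    import torch
    from torchcodec.decoders import VideoDecoder

    decoder = VideoDecoder(video_path)
    # Values and dtype copied from `frame_indices` in the failing call.
    indices = torch.tensor([11, 34, 56, 79, 102, 125, 147, 170], dtype=torch.int32)

    try:
        batch = decoder.get_frames_at(indices=indices)
        print("tensor indices OK:", tuple(batch.data.shape))
    except RuntimeError as err:
        print("tensor indices failed:", err)

    # Plain Python ints match the op's declared `int[] frame_indices` argument.
    batch = decoder.get_frames_at(indices=indices.tolist())
    print("list indices OK:", tuple(batch.data.shape))


if __name__ == "__main__":
    if len(sys.argv) > 1:
        reproduce(sys.argv[1])
    else:
        print(f"usage: {sys.argv[0]} PATH_TO_VIDEO")
```

If the tensor call fails and the list call succeeds, a caller-side workaround would be to convert the indices before decoding at `src/transformers/video_utils.py:602`, e.g. `decoder.get_frames_at(indices=indices.tolist())`; the alternative fix would be on the torchcodec side, accepting tensors in `get_frames_at`.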