llava model compile output regression caused by check_model_inputs #40964

@jiqing-feng

Description

System Info

torch 2.10.0.dev20250914+cpu
transformers 4.57.0.dev0

Who can help?

@zucchini-nlp

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Run the following code on CPU:

import av
import cv2
import torch
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
from transformers import LlavaProcessor, LlavaForConditionalGeneration


model_id = "llava-hf/llava-interleave-qwen-7b-hf"
processor = LlavaProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id, dtype=torch.bfloat16)

def read_video_pyav(container, indices):
    '''
    Decode the video with PyAV decoder.
    Args:
        container (`av.container.input.InputContainer`): PyAV container.
        indices (`List[int]`): List of frame indices to decode.
    Returns:
        result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
    '''
    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame)
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])

def sample_frames(path, num_frames):
    video = cv2.VideoCapture(path)
    total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    interval = max(total_frames // num_frames, 1)  # guard against videos shorter than num_frames
    frames = []
    for i in range(total_frames):
        ret, frame = video.read()
        if not ret:
            continue
        if i % interval == 0:
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    video.release()
    return frames[:num_frames]


# define a chat history and use `apply_chat_template` to get correctly formatted prompt
# Each value in "content" has to be a list of dicts with types ("text", "image", "video") 
conversation = [
    {

        "role": "user",
        "content": [
            {"type": "text", "text": "Why is this video funny?"},
            {"type": "video"},
            ],
    },
]

prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)  # note: overridden by the manual prompt below

video_path = hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset")
container = av.open(video_path)  # note: opened but unused; frames are sampled with OpenCV below

# sample 6 frames uniformly from the video; more can be sampled for longer videos
videos = sample_frames(video_path, 6)
user_prompt = conversation[0]["content"][0]["text"]
toks = "<image>" * 6
prompt = (
    "<|im_start|>user"
    + toks
    + f"\n{user_prompt}<|im_end|><|im_start|>assistant"
)
inputs = processor(text=prompt, images=videos, return_tensors="pt").to(
    model.device, model.dtype
)

# force greedy decoding with a fixed token budget so the eager and compiled runs are directly comparable
generation_config = model.generation_config
generation_config.do_sample = False
generation_config.use_cache = True
generation_config.temperature = 1.0
generation_config.max_new_tokens = 10
generation_config.min_new_tokens = 10
generation_config.top_p = 1.0
# a static cache keeps the KV cache at a fixed shape, which torch.compile needs to avoid recompilations
generation_config.cache_implementation = "static"


output = model.generate(**inputs, generation_config=generation_config)
print("eager model output:")
print(processor.decode(output[0][2:], skip_special_tokens=True))
print("\n")

# compile the forward pass and rerun the exact same generation for comparison
model.forward = torch.compile(model.forward)
output = model.generate(**inputs, generation_config=generation_config)
print("compile model output:")
print(processor.decode(output[0][2:], skip_special_tokens=True))
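
For reference, a minimal sketch (not part of the original report) to check whether the divergence already appears in the prefill forward pass; it assumes a fresh `model` whose forward has not yet been swapped for the compiled version, i.e. run it before the compile step above:

# eager prefill logits
with torch.no_grad():
    eager_logits = model(**inputs).logits

# compiled prefill logits on the exact same inputs
compiled_forward = torch.compile(model.forward)
with torch.no_grad():
    compiled_logits = compiled_forward(**inputs).logits

# bfloat16 will not match bit-exactly; a large gap points at a real numerical
# divergence rather than dtype noise
print((eager_logits - compiled_logits).abs().max())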

Expected behavior

The compiled model should generate the same tokens as the eager model. Before PR #40342 the two runs match; after it, the compiled run diverges.

Output before PR #40342

eager model output:
Why is this video funny?assistant The video is humorous because the baby is wearing oversized

compile model output:
Why is this video funny?assistant The video is humorous because the baby is wearing oversized

Output after PR #40342

eager model output:
Why is this video funny?assistant The video is humorous because the baby is wearing oversized

compile model output:
Why is this video funny?assistant The video is humorous because it shows a baby attempting
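
To pinpoint where the two generations diverge, a small follow-up sketch (the names `eager_output` and `compiled_output` are hypothetical; they stand for the two `output` tensors produced by the eager and compiled runs in the repro):

eager_ids = eager_output[0].tolist()
compiled_ids = compiled_output[0].tolist()
# first position at which the compiled generation departs from the eager one
first_diff = next(
    (i for i, (a, b) in enumerate(zip(eager_ids, compiled_ids)) if a != b), None
)
print(f"first diverging token index: {first_diff}")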
