In [34]:
import numpy as np
from PIL import Image
import requests
import av
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoProcessor, LlavaNextVideoForConditionalGeneration


def read_video_pyav(container, indices):
    """
    Pull a chosen subset of frames out of a video using the PyAV decoder.

    Args:
        container (`av.container.input.InputContainer`): Opened PyAV container to read from.
        indices (`List[int]`): Frame numbers to keep. The first and last entries bound
            the range of frames that is scanned.
    Returns:
        result (np.ndarray): Stacked RGB frames of shape (num_frames, height, width, 3).
    """
    # Rewind first so the enumeration below counts frames from the start of the stream.
    container.seek(0)
    first_wanted, last_wanted = indices[0], indices[-1]

    selected = []
    for frame_no, frame in enumerate(container.decode(video=0)):
        if frame_no > last_wanted:
            # Nothing past the last requested index is needed; stop decoding.
            break
        if frame_no < first_wanted:
            continue
        if frame_no in indices:
            selected.append(frame)

    return np.stack([picked.to_ndarray(format="rgb24") for picked in selected])


# Single checkpoint id shared by the model and its processor so the two cannot drift apart.
MODEL_ID = "llava-hf/LLaVA-NeXT-Video-7B-hf"

model = LlavaNextVideoForConditionalGeneration.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype=torch.float16
)
processor = AutoProcessor.from_pretrained(MODEL_ID)

prompt = "USER: <video>\nWhy is this video funny? ASSISTANT:"
video_path = hf_hub_download(
    repo_id="raushan-testing-hf/videos-test",
    filename="sample_demo_1.mp4",
    repo_type="dataset",
)

# Sample 8 frames uniformly from the video (the model was trained with 32 frames per
# video, but this video is short).
# The container is opened in a `with` block so its file handle is released as soon as
# the selected frames have been copied into the `clip` array.
with av.open(video_path) as container:
    total_frames = container.streams.video[0].frames
    indices = np.arange(0, total_frames, total_frames / 8).astype(int)
    clip = read_video_pyav(container, indices)

inputs_video = processor(text=prompt, videos=clip, return_tensors="pt").to(model.device)

# Keep the inputs as a plain dict so later cells can treat them as an ordinary mapping.
inputs_video = dict(inputs_video)
# print(inputs_video.input_ids[:, -10:])
# load an image to generate from an image
# prompt = "USER:<image>\nWhat is shown in this image? ASSISTANT:"
# url = "https://www.ilankelman.org/stopsigns/australia.jpg"
# image = Image.open(requests.get(url, stream=True).raw)
# inputs_image = processor(text=prompt, images=image, return_tensors="pt").to(
#     model.device
# )


def update_positional_and_cache_ids(inputs_video, first_input=True):
    """
    Set `position_ids` and `cache_position` for the next forward pass.

    Args:
        inputs_video (`Dict`): Dictionary containing the input tensors. Must hold
            `input_ids`; when `first_input` is False it must also hold the
            `position_ids` written by the previous call.
        first_input (`bool`): Whether this is the first input or not in auto-regressive
            generation.
    Returns:
        inputs_video (`Dict`): Shallow copy of the input with `position_ids` of shape
            (batch_size, num_new_tokens) and `cache_position` of shape (num_new_tokens,).
    """
    # Shallow copy so the caller's dictionary is never modified.
    inputs_video = dict(inputs_video)
    input_ids = inputs_video["input_ids"]
    batch_size = input_ids.shape[0]

    if first_input:
        # First step: one position per input token, 0 .. num_tokens - 1.
        num_tokens = input_ids.shape[1]
        ids = torch.arange(num_tokens, device=input_ids.device)
        inputs_video["cache_position"] = ids
        # The key must be spelled `position_ids`; a misspelled key never reaches the
        # model's positional argument.
        inputs_video["position_ids"] = ids.expand(batch_size, num_tokens)
    else:
        # Later steps feed exactly one new token, placed right after the largest
        # position used so far.
        next_position = (torch.max(inputs_video["position_ids"]) + 1).reshape(1)
        # `cache_position` indexes the new tokens (length 1 here); it has no batch
        # dimension, so it must not be expanded to `batch_size`.
        inputs_video["cache_position"] = next_position
        inputs_video["position_ids"] = next_position.expand(batch_size, 1)

    return inputs_video


def prepare_inputs_for_generation(inputs_video, predicted_outputs=None):
    """
    Prepare the inputs for the next forward pass of greedy auto-regressive generation.

    Args:
        inputs_video (`Dict`): Dictionary containing the input tensors used for the
            previous forward pass (or the processor output for the first pass).
        predicted_outputs (`Dict | None`): The predicted outputs from the model with
            `inputs_video`. Contains `logits` and `past_key_values`. None if this is
            the first input.
    Returns:
        inputs_video (`Dict`): Updated shallow copy with new input tensors.
    """
    # Shallow copy so the caller's dictionary is never modified.
    inputs_video = dict(inputs_video)

    if predicted_outputs is None:
        # First step: feed the whole prompt with an empty cache.
        inputs_video = update_positional_and_cache_ids(inputs_video, first_input=True)
        inputs_video["past_key_values"] = None
    else:
        # Greedy decoding: the next input token is the arg-max of the kept logits.
        inputs_video["input_ids"] = predicted_outputs["logits"].argmax(dim=-1)

        # Grow the mask by one attended column for the new token. It is appended on the
        # right because the new token follows everything already in the cache, and
        # `new_ones` keeps the mask's dtype and device.
        mask = inputs_video["attention_mask"]
        inputs_video["attention_mask"] = torch.cat(
            [mask, mask.new_ones((mask.shape[0], 1))], dim=1
        )

        inputs_video = update_positional_and_cache_ids(inputs_video, first_input=False)
        inputs_video["past_key_values"] = predicted_outputs["past_key_values"]
        # The video pixels were passed in on the first step; do not pass them again.
        inputs_video["pixel_values_videos"] = None

    # Only the last position's logits are needed to choose the next token.
    inputs_video["logits_to_keep"] = 1
    inputs_video["use_cache"] = True
    return inputs_video

# print(inputs_video)
# Generate from video
# Show which container type `inputs_video` currently is before unpacking it below.
print(type(inputs_video))

# Ask `generate` for a structured result (sequences plus per-step logits), capped at
# 200 newly generated tokens.
generated_output = model.generate(
    max_new_tokens=200,
    return_dict_in_generate=True,
    output_logits=True,
    **inputs_video,
)

# Decode the whole batch, then keep the first sample as the cell's displayed value.
processor.batch_decode(
    generated_output.sequences,
    clean_up_tokenization_spaces=False,
    skip_special_tokens=True,
)[0]


# max_new_tokens = 50
# with torch.no_grad():
#     # inputs_video = processor(text=prompt, videos=clip, return_tensors="pt").to(model.device)
#     # inputs_video.input_ids = inputs_video.input_ids[:, :1]
#     # inputs_video.attention_mask = inputs_video.attention_mask[:, :1]
#     # inputs_video.video_mask = inputs_video.video_mask[:, :1]
#     # inputs_video.video_attention_mask = inputs_video.video_attention_mask[:, :1]
#     # inputs_video.video_position_ids = inputs_video.video_position_ids[:, :1]
#     for _ in range(max_new_tokens):
#         # print(inputs_video.input_ids.shape)
#         output = model(
#             input_ids=inputs_video.input_ids,
#             attention_mask=inputs_video.attention_mask,
#             pixel_values_videos=inputs_video.pixel_values_videos,
#             past_key_values=inputs_video.past_key_values,
#             logits_to_keep=1,
#         )
#         # output = model(logits_to_keep=1, **inputs_video)
#         predicted_ids = output.logits.argmax(-1)
#         print(predicted_ids)
#         # inputs_video.input_ids = torch.cat([inputs_video.input_ids, predicted_ids], dim=1)

#         # inputs_video = inputs_video
#         inputs_video.input_ids = predicted_ids
#         inputs_video.attention_mask = torch.cat(
#             [
#                 torch.ones(
#                     (inputs_video.attention_mask.shape[0], 1), device=model.device
#                 ),
#                 inputs_video.attention_mask,
#             ],
#             dim=1,
#         )
#         inputs_video.past_key_values = output.past_key_values
#         inputs_video.pixel_values_videos = None
#         positional_ids =
# inputs_video.

# # Generate from image
# generate_ids = model.generate(**inputs_image, max_new_tokens=50)
# processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

Loading checkpoint shards: 100%|██████████| 3/3 [00:04<00:00,  1.33s/it]
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


<class 'dict'>
{'cache_position': tensor([   0,    1,    2,  ..., 1167, 1168, 1169], device='cuda:0'), 'past_key_values': <transformers.cache_utils.DynamicCache object at 0x725543605070>, 'input_ids': tensor([[    1,  3148,  1001,  ...,  9047, 13566, 29901]], device='cuda:0'), 'inputs_embeds': None, 'position_ids': tensor([[   0,    1,    2,  ..., 1167, 1168, 1169]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1]], device='cuda:0'), 'logits_to_keep': 1, 'use_cache': True, 'pixel_values': None, 'pixel_values_videos': tensor([[[[[-1.2813e+00, -1.2521e+00, -1.2521e+00,  ..., -1.0039e+00,
            -9.8935e-01, -9.6015e-01],
           [-1.2375e+00, -1.2521e+00, -1.2521e+00,  ..., -1.0039e+00,
            -9.8935e-01, -9.6015e-01],
           [-1.2521e+00, -1.2521e+00, -1.2521e+00,  ..., -1.0039e+00,
            -9.8935e-01, -9.6015e-01],
           ...,
           [ 1.6393e-01,  2.2232e-01,  2.9531e-01,  ..., -7.2658e-01,
            -7.2658e-01, -7.1198e-01],
    

"USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and endearing nature of the situation. The baby is wearing glasses and appears to be reading a book, which is a humorous and endearing sight because babies are typically not expected to be able to read at such a young age. The glasses add a touch of whimsy and the baby's expression and actions suggest that they are deeply engrossed in the book, which is a playful and amusing portrayal of a child's curiosity and imagination. The video captures a moment of innocence and playfulness that can be seen as cute and entertaining to viewers."

In [32]:
max_new_tokens = 500
eos_token_id = processor.tokenizer.eos_token_id

# Per-step state lives in its own names so `inputs_video` keeps the untouched prompt
# inputs and this cell gives the same result when it is re-run.
step_inputs = inputs_video
output = None  # no prediction yet: the first pass feeds the full prompt
generated_ids = []  # one (batch_size,) tensor of token ids per decoding step
finished = None  # per-sequence flag: has this sequence emitted EOS yet?

with torch.no_grad():
    for _ in range(max_new_tokens):
        step_inputs = prepare_inputs_for_generation(step_inputs, output)
        output = model(**step_inputs)

        # Greedy choice from the last position's logits, recorded so the generated
        # text can be decoded afterwards.
        next_ids = output.logits[:, -1, :].argmax(dim=-1)
        generated_ids.append(next_ids)

        # Stop early once every sequence in the batch has produced end-of-sequence.
        if eos_token_id is not None:
            is_eos = next_ids == eos_token_id
            finished = is_eos if finished is None else finished | is_eos
            if bool(finished.all()):
                break

generated_text = processor.batch_decode(
    torch.stack(generated_ids, dim=1), skip_special_tokens=True
)
generated_text[0]

{'input_ids': tensor([[    1,  3148,  1001,  ...,  9047, 13566, 29901]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1]], device='cuda:0'), 'pixel_values_videos': tensor([[[[[-1.2813e+00, -1.2521e+00, -1.2521e+00,  ..., -1.0039e+00,
            -9.8935e-01, -9.6015e-01],
           [-1.2375e+00, -1.2521e+00, -1.2521e+00,  ..., -1.0039e+00,
            -9.8935e-01, -9.6015e-01],
           [-1.2521e+00, -1.2521e+00, -1.2521e+00,  ..., -1.0039e+00,
            -9.8935e-01, -9.6015e-01],
           ...,
           [ 1.6393e-01,  2.2232e-01,  2.9531e-01,  ..., -7.2658e-01,
            -7.2658e-01, -7.1198e-01],
           [ 1.0553e-01,  1.6393e-01,  2.6612e-01,  ..., -7.1198e-01,
            -7.1198e-01, -6.9738e-01],
           [ 4.7139e-02,  1.0553e-01,  1.7853e-01,  ..., -6.9738e-01,
            -6.9738e-01, -6.8278e-01]],

          [[-1.3769e+00, -1.3469e+00, -1.3469e+00,  ..., -1.1818e+00,
            -1.1668e+00, -1.1368e+00],
           [-1.3319e+00, -1.3469e+

In [36]:
# Tokenize a two-prompt batch, padding on the right and truncating so every row is
# exactly 50 tokens long, to inspect the resulting `input_ids` and `attention_mask`.
processor.tokenizer(
    [
        "USER: <video>\nWhy is this video funny? ASSISTANT:",
        "Hello there!",
    ],
    add_special_tokens=True,
    padding="max_length",
    padding_side="right",
    truncation=True,
    max_length=50,
)

{'input_ids': [[1, 3148, 1001, 29901, 29871, 32000, 13, 11008, 338, 445, 4863, 2090, 1460, 29973, 319, 1799, 9047, 13566, 29901, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 15043, 727, 29991, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]}

In [19]:
from torch import nn
import torch

# Shape check for `nn.CrossEntropyLoss` with `reduction="none"`: the class dimension
# (dim 1, size 5) is reduced away, leaving one loss value per (batch, position).
# The tensors get their own descriptive names so this cell neither shadows the builtin
# `input` nor rebinds the notebook-wide `output` variable.
loss = nn.CrossEntropyLoss(reduction="none")

# Example of target with class indices
logits_for_indices = torch.randn(3, 5, 100, requires_grad=True)
target_indices = torch.empty(3, 100, dtype=torch.long).random_(5)
loss_for_indices = loss(logits_for_indices, target_indices)
print(loss_for_indices.shape)
loss_for_indices.mean().backward()

# Example of target with class probabilities
logits_for_probs = torch.randn(3, 5, 100, requires_grad=True)
target_probs = torch.randn(3, 5, 100).softmax(dim=1)
loss_for_probs = loss(logits_for_probs, target_probs)
print(loss_for_probs.shape)
loss_for_probs.mean().backward()

torch.Size([3, 100])
torch.Size([3, 100])


In [20]:
# "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and endearing situation of a young child, who appears to be a baby or toddler, attempting to read a book. The child's small size and the fact that they are reading a book"