In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '6'

In [2]:
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adagrad
from tqdm import tqdm
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
import random
from collections import defaultdict
import warnings
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import re
import bisect
import shutil
import json
from time import perf_counter

# warnings.filterwarnings("ignore")
import os
import pickle
from sentence_transformers import SentenceTransformer
import av
from transformers import VideoLlavaForConditionalGeneration, VideoLlavaProcessor
from huggingface_hub import hf_hub_download


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import av
import numpy as np
from transformers import VideoLlavaProcessor, VideoLlavaForConditionalGeneration

def read_video_pyav(container, indices):
    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame)
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])


model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf")
processor = VideoLlavaProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B-hf")


Loading checkpoint shards: 100%|████████████████████████████████████████████████████████| 3/3 [00:31<00:00, 10.33s/it]


In [4]:
# video_path = "data/Charades_v1_480/6H78U.mp4"
video_path = "data/Charades_v1_480/0A8CF.mp4"

In [None]:

prompt = "USER: <video>Why is this video funny? ASSISTANT:"

container = av.open(video_path)

# sample uniformly 8 frames from the video
total_frames = container.streams.video[0].frames
indices = np.arange(0, total_frames, total_frames / 8).astype(int)
clip = read_video_pyav(container, indices)

inputs = processor(text=prompt, videos=clip, return_tensors="pt")

# Generate
generate_ids = model.generate(**inputs, max_length=80)
print(processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])


Expanding inputs for image tokens in Video-LLaVa should be done in processing. Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. Using processors without these attributes in the config is deprecated and will throw an error in v4.44.
Expanding inputs for image tokens in Video-LLaVa should be done in processing. Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. Using processors without these attributes in the config is deprecated and will throw an error in v4.47.


In [60]:
class VideoQADataset(Dataset):
    def __init__(
        self,
        json_file,
        video_dir="/data/user_data/gdhanuka/STAR_dataset/Charades_v1_480",
        sampling_fps=4,
        num_frames=8,
        use_fps=True
    ):
        """
        Args:
            json_file (str): Path to the JSON file containing the dataset.
            video_features_dir (str): Directory containing precomputed CLIP features for videos.
            num_frames (int): Number of frames to sample from each video.
        """
        with open(json_file, "rb") as f:
            self.data = pickle.load(f)
        self.video_dir = video_dir
        self.sampling_fps = sampling_fps
        self.num_frames = num_frames
        self.use_fps = use_fps

    def __len__(self):
        return len(self.data)

    def read_video_pyav(self, container, indices):
        """
        Decode the video with PyAV decoder.
        Args:
            container (`av.container.input.InputContainer`): PyAV container.
            indices (`List[int]`): List of frame indices to decode.
        Returns:
            result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
        """
        frames = []
        container.seek(0)
        start_index = indices[0]
        end_index = indices[-1]
        for i, frame in enumerate(container.decode(video=0)):
            if i > end_index:
                break
            if i >= start_index and i in indices:
                frames.append(frame)
        return np.stack([x.to_ndarray(format="rgb24") for x in frames])

    def read_video_pyav2(self, video_path, start, end, num_frames=8):
        """Reads a video for given start-end timestamps interval and uniformly samples 8 frames of it"""
        container = av.open(video_path)
        video = container.streams.get(0)[0]

        av_timestamps = [
            int(packet.pts * video.time_base)
            for packet in container.demux(video)
            if packet.pts is not None
        ]

        av_timestamps.sort()
        start_id = bisect.bisect_left(av_timestamps, start)
        end_id = bisect.bisect_left(av_timestamps, end)

        # in case it is a very short video, lets take a longer duration and sample
        if end_id - start_id < 10:
            end_id += 10
            start_id -= 10

        end_id = min(len(av_timestamps) - 1, end_id)
        start_id = max(1, start_id)

        # We sample 8 frames for tuning following the original paper
        # But we can increase the number of frames for longer videos and check out if it helps performance
        # Change the below "8" to any number of frames you want, and note that more frames -> more computational resources needed
        indices = np.linspace(start_id, end_id, num_frames).astype(int)

        frames = []
        container.seek(0)
        for i, frame in enumerate(container.decode(video=0)):
            if i > end_id:
                break
            if i >= start_id and i in indices:
                frames.append(frame)
        assert (
            len(frames) == num_frames
        ), f"Got {len(frames)} frames but should be {num_frames}. Check the indices: {indices};, start_id: {start_id}, end_id: {end_id}. Len of video is {len(av_timestamps)} frames."
        return np.stack([x.to_ndarray(format="rgb24") for x in frames]), indices

    def read_video_pyav3(self, video_path, start, end, sampling_fps=4):
        """Reads a video clip from start-end timestamps and samples frames at specified FPS"""
        container = av.open(video_path)
        video = container.streams.video[0]
        
        # Calculate number of frames needed based on duration and sampling FPS
        duration = end - start
        num_frames = int(round(sampling_fps * duration))
        num_frames = max(1, num_frames)  # Ensure at least 1 frame

        # Get sorted presentation timestamps
        av_timestamps = [
            int(packet.pts * video.time_base)
            for packet in container.demux(video)
            if packet.pts is not None
        ]
        av_timestamps.sort()

        # Find frame indices bounding our clip
        start_id = bisect.bisect_left(av_timestamps, start)
        end_id = bisect.bisect_left(av_timestamps, end)

        # Expand window for short clips
        if end_id - start_id < 10:
            end_id += 10
            start_id -= 10

        # Clamp to valid range
        end_id = min(len(av_timestamps) - 1, end_id)
        start_id = max(0, start_id)

        # Generate sampling indices
        indices = np.linspace(start_id, end_id, num_frames, dtype=int)

        # Extract frames
        frames = []
        container.seek(0)
        for i, frame in enumerate(container.decode(video=0)):
            if i > end_id:
                break
            if i >= start_id and i in indices:
                frames.append(frame)
        
        assert len(frames) == num_frames, (
            f"Frame sampling failed: Expected {num_frames}, got {len(frames)}. "
            f"Time range: {start}-{end}s ({duration}s), Sampling FPS: {sampling_fps}."
        )
        
        return np.stack([x.to_ndarray(format="rgb24") for x in frames]), indices


    def __getitem__(self, idx):
        
        item = self.data[idx]
        video_id = item["video_id"]
        start, end = item["start"], item["end"]
        question = item["question"]
        question_id = item["question_id"]
        choices = [choice["choice"] for choice in item["choices"]]
        answer_idx = next(
            i
            for i, choice in enumerate(item["choices"])
            if choice["choice"] == item["answer"]
        )

        video_path = os.path.join(self.video_dir, f"{video_id}.mp4")
        start_time = perf_counter()
        if self.use_fps:
            video_frames, frame_idx = self.read_video_pyav3(video_path, start, end, sampling_fps=self.sampling_fps)
        else:
            video_frames, frame_idx = self.read_video_pyav2(video_path, start, end, num_frames=self.num_frames)
        video_frames = torch.from_numpy(video_frames).permute(0, 3, 1, 2).float() # (#frames, channel, h, w)

        all_text_inputs = []
        for choice in choices:
            all_text_inputs.append(f"{question} [SEP] {choice}")
        end = perf_counter()
        return {
            "video_frames": video_frames,  # Video features
            "question": question,
            "video_id": video_id,
            "choices": choices,
            "answer_idx": answer_idx,
            "category": item["question_id"].split("_")[0],  # Question category
            "all_text_inputs": all_text_inputs,
            "data_proc_time": end-start,
            "question_id": question_id,
            "frame_ids": frame_idx
        }


# def collate_fn(batch):
#     """Handles variable-sized video frames using smart padding"""
#     # Separate video frames and metadata
#     videos = [item["video_frames"] for item in batch]
#     questions = [item["question"] for item in batch]
#     choices = [item["choices"] for item in batch]
#     answer_idxs = torch.stack([torch.tensor(item["answer_idx"]) for item in batch])

#     # Pad videos to max dimensions in batch
#     max_frames = max(vid.shape[0] for vid in videos)
#     max_height = max(vid.shape[2] for vid in videos)
#     max_width = max(vid.shape[3] for vid in videos)

#     padded_videos = []
#     for vid in videos:
#         # Pad: (width_left, width_right, height_top, height_bottom, frames_front, frames_back)
#         pad_width = max_width - vid.shape[3]
#         pad_height = max_height - vid.shape[2]
#         pad_frames = max_frames - vid.shape[0]

#         padded = F.pad(vid, (0, pad_width, 0, pad_height, 0, 0, 0, pad_frames))
#         padded_videos.append(padded)

#     return {
#         "video_frames": torch.stack(padded_videos),
#         "question": questions,
#         "choices": choices,
#         "answer_idx": answer_idxs,
#     }


In [56]:
class VideoQAModel:
    def __init__(self, load_in_bits=16):

        device = 'cuda'
        compute_dtype = 'fp16'
        double_quant = True
        quant_type = 'nf4'

        compute_dtype = (torch.float16 if compute_dtype == 'fp16' else (torch.bfloat16 if compute_dtype == 'bf16' else torch.float32))
        
        bnb_model_from_pretrained_args = {}
        if load_in_bits in [4, 8]:
            from transformers import BitsAndBytesConfig
            bnb_model_from_pretrained_args.update(dict(
                device_map={"": device},
                # load_in_4bit=load_in_bits == 4,
                # load_in_8bit=load_in_bits == 8,
                quantization_config=BitsAndBytesConfig(
                    load_in_4bit=load_in_bits == 4,
                    load_in_8bit=load_in_bits == 8,
                    llm_int8_skip_modules=["mm_projector"],
                    llm_int8_threshold=6.0,
                    llm_int8_has_fp16_weight=False,
                    bnb_4bit_compute_dtype=compute_dtype,
                    bnb_4bit_use_double_quant=double_quant,
                    bnb_4bit_quant_type=quant_type # {'fp4', 'nf4'}
                )
            ))

            self.model = VideoLlavaForConditionalGeneration.from_pretrained(
                "LanguageBind/Video-LLaVA-7B-hf",
                torch_dtype=compute_dtype,
                attn_implementation="flash_attention_2",
                **bnb_model_from_pretrained_args
            )
        else:
            self.model = VideoLlavaForConditionalGeneration.from_pretrained(
                "LanguageBind/Video-LLaVA-7B-hf",
                torch_dtype=torch.float16,
                device_map="auto",
                attn_implementation="flash_attention_2",
            ).to("cuda")
        self.processor = VideoLlavaProcessor.from_pretrained(
            "LanguageBind/Video-LLaVA-7B-hf"
        )

    def generate(self, inputs, max_new_tokens=500):
        outputs = self.model.generate(
            **inputs, max_new_tokens=max_new_tokens,
            return_dict_in_generate=True, output_scores=True
        )

        decoded = self.processor.batch_decode(
            outputs.sequences,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )

        print(decoded)
        
        first_token_probs = torch.nn.functional.softmax(outputs.scores[0], dim=-1)
        
        # Get token IDs for numbers 1-4
        token_ids = [self.processor.tokenizer.convert_tokens_to_ids(str(i)) for i in [1,2,3,4]]
        
        # Create probability dictionary for each sample in batch
        prob_list = []
        for batch_idx in range(first_token_probs.shape[0]):
            probs = [
                first_token_probs[batch_idx, token_ids[i]].item()
                for i in range(4)
            ]
            prob_list.append(probs)
        
        logits_list = []
        for batch_idx in range(outputs.scores[0].shape[0]):
            logits = [
                outputs.scores[0][batch_idx, token_ids[i]].item()
                for i in range(4)
            ]
            logits_list.append(logits)


        return decoded, prob_list, logits_list

    def video_qa(self, video_frames, question, choices, max_new_tokens=500):
        choice_with_idx = [f'"{i+1}": {choice}\n' for i, choice in enumerate(choices)]
        prompt = f"USER: <video>\n {question} \n {choice_with_idx} Answer with the option's index from the given choices directly. \n ASSISTANT: "
        inputs = self.processor(
            text=prompt, videos=video_frames, return_tensors="pt", max_length=4096
        ).to("cuda")
        decoded, probs, logits = self.generate(inputs, max_new_tokens=max_new_tokens)
        return decoded[0], probs[0], prompt, logits[0]

    def video_qa_batch(self, video_batch, questions, choices_batch):
        prompts = []
        for q, choices in zip(questions, choices_batch):
            opts = "\n".join([f"{i+1}: {c}" for i, c in enumerate(choices)])
            prompts.append(
                f"USER: <video>\n According to the video choose the correct answer, {questions} \n {opts} ASSISTANT: "
            )

        inputs = self.processor(
            text=prompts,
            videos=[v for v in video_batch],  # Process full batch
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to("cuda", torch.float16)

        outputs = self.model.generate(**inputs, max_new_tokens=20)
        return self.processor.batch_decode(outputs, skip_special_tokens=True)


In [40]:
import torch
import re
from typing import List, Dict, Any, Optional, Union
import numpy as np
from time import perf_counter
import json
from tqdm import tqdm
import os


class VideoOfThoughtPredictor:
    def __init__(self, video_llava_model):
        """
        Initialize the Video-of-Thought predictor with a VideoLLAVA model.
        
        Args:
            video_llava_model: An instance of the VideoLLAVA model class
        """
        self.model = video_llava_model
        
    def _generate_response(self, video_frames, prompt, max_new_tokens=100):
        """
        Generate a response from the VideoLLAVA model.
        
        Args:
            video_frames: Video frame tensors
            prompt: Text prompt
            max_new_tokens: Maximum number of tokens to generate
            
        Returns:
            Generated text response
        """
        inputs = self.model.processor(
            text=prompt, 
            videos=video_frames, 
            return_tensors="pt", 
            max_length=4096
        ).to("cuda")
        
        # Use more controlled generation parameters
        outputs = self.model.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # do_sample=False,  # Use greedy decoding for more consistent outputs
            # temperature=0.1,   # Lower temperature for more focused responses
            # num_beams=1,
            # early_stopping=True,
            # pad_token_id=self.model.processor.tokenizer.pad_token_id,
            # eos_token_id=self.model.processor.tokenizer.eos_token_id
        )

        print(outputs)
        
        decoded = self.model.processor.batch_decode(
            outputs,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )
        
        # Extract just the assistant's response, removing the prompt
        response = decoded[0]
        if "ASSISTANT:" in response:
            response = response.split("ASSISTANT:")[-1].strip()
            
        return response
    
    def step_1_identify_targets(self, video_frames, question, is_multi_choice=True):
        """
        Step 1: Task Definition and Target Identification
        
        Args:
            video_frames: Video frame tensors
            question: The question text
            is_multi_choice: Whether the question is multiple choice
            
        Returns:
            The identified targets in the video relevant to the question
        """
        if is_multi_choice:
            task_definition = "You are an expert in video analysis."
        else:
            task_definition = "You are an expert in video analysis."
        
        prompt = f"USER: <video>\n{task_definition}\n\nGiven the question: \"{question}\", what are the key objects, people, or elements in the video that need to be tracked to answer this question?\n\nProvide a concise list of the key targets.\nASSISTANT:"
        
        response = self._generate_response(video_frames, prompt, max_new_tokens=100)
        return response
    
    def step_2_object_description(self, video_frames, targets, question):
        """
        Step 2: Object Description (adapted from Object Tracking in the original paper)
        
        Args:
            video_frames: Video frame tensors
            targets: The identified targets from step 1
            question: The original question
            
        Returns:
            Description of the targets throughout the video
        """
        prompt = f"USER: <video>\nDescribe in detail the following elements that are relevant to answering the question \"{question}\":\n\n{targets}\n\nFocus on their appearance, movement, and interactions in the video.\nASSISTANT:"
        
        response = self._generate_response(video_frames, prompt, max_new_tokens=150)
        return response
    
    def step_3_action_analysis(self, video_frames, object_descriptions, question):
        """
        Step 3: Action Analysis
        
        Args:
            video_frames: Video frame tensors
            object_descriptions: The object descriptions from step 2
            question: The original question
            
        Returns:
            Analysis of actions and implications
        """
        prompt = f"USER: <video>\nBased on the question \"{question}\" and these observations:\n\n{object_descriptions}\n\nAnalyze what actions are occurring in the video, their sequence, and their implications. Include both direct observations and reasonable inferences.\nASSISTANT:"
        
        response = self._generate_response(video_frames, prompt, max_new_tokens=200)
        return response
    
    def step_4_answer_scoring(self, video_frames, question, choices, action_analysis):
        """
        Step 4: Answer Scoring and Ranking for multi-choice questions
        
        Args:
            video_frames: Video frame tensors
            question: The question text
            choices: List of answer choices
            action_analysis: The action analysis from step 3
            
        Returns:
            Final answer with scores
        """
        # First, score each choice individually
        scores_and_rationales = []
        
        for i, choice in enumerate(choices):
            prompt = f"USER: <video>\nQuestion: {question}\nCandidate answer: {choice}\n\nBased on the video and this analysis:\n{action_analysis}\n\nRate the likelihood of this answer being correct (1-10) and explain why.\nASSISTANT:"
            
            response = self._generate_response(video_frames, prompt, max_new_tokens=150)
            scores_and_rationales.append(response)
        
        # Now do the final ranking and selection
        prompt = f"USER: <video>\nFor the question: \"{question}\", here are the ratings for each answer choice:\n\n"
        
        for i, (choice, rationale) in enumerate(zip(choices, scores_and_rationales)):
            prompt += f"Option {i+1}: {choice}\nRating: {rationale}\n\n"
        
        prompt += "Based on these ratings, which answer is most likely correct and why? Give the answer number.\nASSISTANT:"
        
        ranking_response = self._generate_response(video_frames, prompt, max_new_tokens=100)
        
        # Extract the final answer index using regex
        try:
            answer_number_match = re.search(r'(\d+)', ranking_response)
            if answer_number_match:
                answer_number = int(answer_number_match.group(1))
                # Adjust to 0-based indexing
                answer_index = answer_number - 1
                if 0 <= answer_index < len(choices):
                    final_answer = choices[answer_index]
                else:
                    final_answer = ranking_response  # Use full response if index out of range
            else:
                final_answer = ranking_response
        except:
            final_answer = ranking_response
            
        return final_answer, ranking_response, scores_and_rationales
    
    def step_5_answer_verification(self, video_frames, question, final_answer, action_analysis):
        """
        Step 5: Answer Verification
        
        Args:
            video_frames: Video frame tensors
            question: The question text
            final_answer: The final answer from step 4
            action_analysis: The action analysis from step 3
            
        Returns:
            Verification of the answer
        """
        prompt = f"USER: <video>\nQuestion: {question}\nSelected answer: {final_answer}\n\nBased on the video evidence and this analysis:\n{action_analysis}\n\nVerify whether this answer is correct. Provide a final verdict (correct/incorrect) with justification.\nASSISTANT:"
        
        response = self._generate_response(video_frames, prompt, max_new_tokens=150)
        return response
    
    def video_qa_reasoning(self, video_frames, question, choices, is_multi_choice=True, output_intermediate_steps=False):
        """
        Complete video QA reasoning process using the Video-of-Thought approach
        
        Args:
            video_frames: Video frame tensors
            question: The question text
            choices: List of answer choices
            is_multi_choice: Whether the question is multiple choice
            output_intermediate_steps: Whether to output intermediate reasoning steps
            
        Returns:
            Final answer and optionally intermediate steps
        """
        print("Step 1: Identifying targets...")
        targets = self.step_1_identify_targets(video_frames, question, is_multi_choice)

        print('target:',targets)
        
        print("Step 2: Describing objects...")
        object_descriptions = self.step_2_object_description(video_frames, targets, question)

        print('object_descriptions:',object_descriptions)
        
        print("Step 3: Analyzing actions...")
        action_analysis = self.step_3_action_analysis(video_frames, object_descriptions, question)

        print('action_analysis:',action_analysis)
        
        print("Step 4: Scoring and ranking answers...")
        final_answer, ranking_response, scores = self.step_4_answer_scoring(
            video_frames, question, choices, action_analysis
        )

        print('final_answer:',final_answer)
        
        print("Step 5: Verifying answer...")
        verification = self.step_5_answer_verification(
            video_frames, question, final_answer, action_analysis
        )

        print('verification:',verification)
        
        # Format the final result
        if is_multi_choice:
            # Try to extract the answer index
            answer_number = "Unknown"
            try:
                answer_number_match = re.search(r'(\d+)', ranking_response)
                if answer_number_match:
                    answer_number = answer_number_match.group(1)
            except:
                pass
                
            final_result = f"Answer: {answer_number}\nRationale: {verification}"
        else:
            final_result = f"Answer: {final_answer}\nRationale: {verification}"
        
        if output_intermediate_steps:
            return {
                "targets": targets,
                "object_descriptions": object_descriptions,
                "action_analysis": action_analysis,
                "scores": scores,
                "ranking": ranking_response,
                "final_answer": final_answer,
                "verification": verification,
                "final_result": final_result
            }
        else:
            return final_result
    
    def video_qa_direct(self, video_frames, question, choices=None, max_new_tokens=100):
        """
        Standard video QA without the step-by-step reasoning process
        
        Args:
            video_frames: Video frame tensors
            question: The question text
            choices: List of answer choices or None for open-ended questions
            max_new_tokens: Maximum number of tokens to generate
            
        Returns:
            Direct answer without step-by-step reasoning
        """
        if choices:
            # Format multiple-choice question
            choice_with_idx = [f'"{i+1}": {choice}\n' for i, choice in enumerate(choices)]
            prompt = f"USER: <video>\n {question} \n {choice_with_idx} Answer with the option's index from the given choices directly. \n ASSISTANT: "
        else:
            # Open-ended question
            prompt = f"USER: <video>\n {question} \n Answer directly based on what you see in the video. \n ASSISTANT: "
        
        response = self._generate_response(video_frames, prompt, max_new_tokens)
        return response


# Example usage:
# video_llava_model = VideoQAModel(load_in_bits=4)
# predictor = VideoOfThoughtPredictor(video_llava_model)
# 
# # For multi-choice question:
# result = predictor.video_qa_reasoning(
#     video_frames=frames,
#     question="What is the person doing in the video?",
#     choices=["Cooking", "Dancing", "Reading", "Exercising"],
#     output_intermediate_steps=True
# )

In [43]:
# Initialize with your VideoLLAVA model
video_llava_model = VideoQAModel(load_in_bits=16)

The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████| 3/3 [00:36<00:00, 12.23s/it]


In [61]:
# val_pkl = "/data/user_data/gdhanuka/STAR_dataset/STAR_val.pkl"
val_pkl = "data/STAR_val.pkl"
# video_dir = "/data/user_data/gdhanuka/STAR_dataset/Charades_v1_480"
video_dir = "data/Charades_v1_480"

# File paths
# results_file = "/home/gdhanuka/STAR_Benchmark/analysis/video_llava_results2.jsonl"
results_file = "analysis/video_llava_4_frames_results.jsonl"
# final_accuracy_file = "/home/gdhanuka/STAR_Benchmark/analysis/video_llava_final_accuracy2.txt"
final_accuracy_file = "analysis/video_llava_4_frames_final_accuracy.txt"

device = "cuda" if torch.cuda.is_available() else "cpu"
dataset = VideoQADataset(val_pkl, video_dir=video_dir, sampling_fps=4, num_frames=8, use_fps=False)
# batched inference not working!!
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,  # disable to easily reproduce
    num_workers=4,
    pin_memory=True,
    # collate_fn=collate_fn,
)
category_correct = defaultdict(int)
category_total = defaultdict(int)


In [62]:
batch = next(iter(dataloader))

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

In [77]:

from IPython.display import HTML
import os


def display_video(video_id, video_dir="data/Charades_v1_480"):
    video_path = os.path.join(video_dir, f"{video_id}.mp4")
    
    html_content = f'''
    <video width="320" height="240" controls>
      <source src="{video_path}" type="video/mp4">
    </video>
    '''
    
    return HTML(html_content)

In [63]:
batch.keys()

dict_keys(['video_frames', 'question', 'video_id', 'choices', 'answer_idx', 'category', 'all_text_inputs', 'data_proc_time', 'question_id', 'frame_ids'])

In [47]:
video_frames = batch["video_frames"][0].to(device)
question = batch["question"][0]
choices = batch["choices"]

print(question)
print(choices)

Which object was tidied up by the person?
[['The closet/cabinet.'], ['The blanket.'], ['The clothes.'], ['The table.']]


In [49]:
predictor = VideoOfThoughtPredictor(video_llava_model)

In [78]:
video_id = batch['video_id'][0]

In [79]:
display_video(video_id)

In [91]:
def video_qa(self, video_frames, question, choices, max_new_tokens=500):
    choice_with_idx = [f'"{i+1}": {choice}\n' for i, choice in enumerate(choices)]
    # prompt = f"USER: <video>\n {question} \n {choice_with_idx} Answer the question. Think step by step. \n ASSISTANT: "
    prompt = "USER: <video>Why is this video funny? ASSISTANT:"
    inputs = self.processor(
        text=prompt, videos=video_frames, return_tensors="pt"
    ).to("cuda")
    
    ouputs = self.model.generate(**inputs, max_new_tokens=max_new_tokens)

    print(ouputs)
    
    return ouputs


In [92]:
video_qa(video_llava_model, video_frames, question, choices)

tensor([[    1,  3148,  1001, 29901, 29871, 32001, 11008,   338,   445,  4863,
          2090,  1460, 29973,   319,  1799,  9047, 13566, 29901,   450, 29892,
           322, 29892,   322,   322,   322,   322,   322, 29973,   322,   322,
           322,   322, 29892,   322, 29892,   322, 29892,   322, 29892,   518,
           322, 29892,   518,   322, 29892,   518,   322, 29892,   322, 29892,
           322, 29892,   322, 29892,   322, 29892,   322, 29892,   322, 29892,
           322, 29892,   322, 29892,   322, 29892,   322, 29892,   322, 29892,
           322, 29892,   322, 29892,   322, 29892,   322, 29892,   322, 29892,
           322, 29892,   322, 29892,   322, 29892,   518,   322, 29892,   322,
         29892,   322, 29892,   518,   322, 29892,   322, 29892,   322, 29892,
           518,   322, 29892,   518,   322, 29892,   518,   322, 29892,   518,
           322, 29892,   322, 29892,   322, 29892,   518,   518,   518,   518,
           322, 29892,   322, 29892,   518,   322, 2

tensor([[    1,  3148,  1001, 29901, 29871, 32001, 11008,   338,   445,  4863,
          2090,  1460, 29973,   319,  1799,  9047, 13566, 29901,   450, 29892,
           322, 29892,   322,   322,   322,   322,   322, 29973,   322,   322,
           322,   322, 29892,   322, 29892,   322, 29892,   322, 29892,   518,
           322, 29892,   518,   322, 29892,   518,   322, 29892,   322, 29892,
           322, 29892,   322, 29892,   322, 29892,   322, 29892,   322, 29892,
           322, 29892,   322, 29892,   322, 29892,   322, 29892,   322, 29892,
           322, 29892,   322, 29892,   322, 29892,   322, 29892,   322, 29892,
           322, 29892,   322, 29892,   322, 29892,   518,   322, 29892,   322,
         29892,   322, 29892,   518,   322, 29892,   322, 29892,   322, 29892,
           518,   322, 29892,   518,   322, 29892,   518,   322, 29892,   518,
           322, 29892,   322, 29892,   322, 29892,   518,   518,   518,   518,
           322, 29892,   322, 29892,   518,   322, 2

In [50]:
# For a multiple-choice question
result = predictor.video_qa_reasoning(
    video_frames=video_frames,
    question=question,
    choices=choices,
    is_multi_choice=True,
    output_intermediate_steps=True  # Set to True to see all reasoning steps
)


Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


Step 1: Identifying targets...
tensor([[    1,  3148,  1001, 29901, 29871, 32001,    13,  3492,   526,   385,
         17924,   297,  4863,  7418, 29889,    13,    13, 29954,  5428,   278,
          1139, 29901,   376,  8809,   436,  1203,   471, 10668,  1000,   701,
           491,   278,  2022, 29973,   613,   825,   526,   278,  1820,  3618,
         29892,  2305, 29892,   470,  3161,   297,   278,  4863,   393,   817,
           304,   367,  5702,   287,   304,  1234,   445,  1139, 29973,    13,
            13,  1184, 29894,   680,   263,  3022,   895,  1051,   310,   278,
          1820, 22525, 29889,    13, 22933,  9047, 13566, 29901,   450,   322,
           322,   322,   322,   322,   322,     2]], device='cuda:0')
target: The and and and and and and
Step 2: Describing objects...
tensor([[    1,  3148,  1001, 29901, 29871, 32001,    13,  4002, 29581,   297,
          9493,   278,  1494,  3161,   393,   526,  8018,   304, 22862,   278,
          1139,   376,  8809,   436,  1203,

KeyboardInterrupt: 

In [9]:

# For an open-ended question
result = predictor.video_qa_reasoning(
    video_frames=frames,
    question="What is the person doing in the video?",
    is_multi_choice=False
)

NameError: name 'frames' is not defined