# Video Encoder testing

## Preprocessor

### Encoder class

In [None]:
import transformers
from transformers import Qwen2VLProcessor

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.jit

from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig
from transformers.models.qwen2_5_vl.modular_qwen2_5_vl import (
    Qwen2_5_VLPreTrainedModel ,
    Qwen2_5_VLVisionConfig,
    Qwen2_5_VisionPatchEmbed,
    Qwen2_5_VisionRotaryEmbedding,
    Qwen2_5_VLVisionBlock,
    Qwen2_5_VLPatchMerger
)

class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
    """Vision encoder assembled from the Qwen2.5-VL building blocks.

    The forward pass embeds the incoming patches, reorders them window by
    window, runs them through the block stack (layers listed in
    ``config.fullatt_block_indexes`` receive the per-frame sequence
    boundaries, every other layer receives the per-window boundaries),
    applies the patch merger and finally restores the pre-window ordering.
    """

    config_class = Qwen2_5_VLVisionConfig
    _no_split_modules = ["Qwen2_5_VLVisionBlock"]

    def __init__(self, config: Qwen2_5_VLVisionConfig, *inputs, **kwargs) -> None:
        """Build the sub-modules described by ``config``."""
        super().__init__(config, *inputs, **kwargs)

        # Plain geometry settings copied off the config for quick access.
        self.spatial_merge_size = config.spatial_merge_size
        self.patch_size = config.patch_size
        self.fullatt_block_indexes = config.fullatt_block_indexes
        self.window_size = config.window_size
        # Number of patches that are grouped into one merged unit.
        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size

        self.patch_embed = Qwen2_5_VisionPatchEmbed(
            patch_size=config.patch_size,
            temporal_patch_size=config.temporal_patch_size,
            in_channels=config.in_channels,
            embed_dim=config.hidden_size,
        )

        # The rotary embedding is sized to half of the per-head dimension.
        per_head_dim = config.hidden_size // config.num_heads
        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(per_head_dim // 2)

        self.blocks = nn.ModuleList(
            [Qwen2_5_VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)]
        )
        self.merger = Qwen2_5_VLPatchMerger(
            dim=config.out_hidden_size,
            context_dim=config.hidden_size,
            spatial_merge_size=config.spatial_merge_size,
        )
        self.gradient_checkpointing = False

    def rot_pos_emb(self, grid_thw):
        """Return the rotary embedding rows for every patch in ``grid_thw``.

        For each ``(t, h, w)`` entry, row and column indices are laid out in
        merge-group order (``spatial_merge_size`` x ``spatial_merge_size``
        tiles), repeated ``t`` times, and then used to index the table
        produced by ``self.rotary_pos_emb`` for the largest spatial extent.
        """
        merge = self.spatial_merge_size
        per_entry_ids = []
        for t, h, w in grid_thw:
            row_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            col_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            # Rows and columns go through the identical regrouping into tiles.
            regrouped = [
                ids.reshape(h // merge, merge, w // merge, merge).permute(0, 2, 1, 3).flatten()
                for ids in (row_ids, col_ids)
            ]
            per_entry_ids.append(torch.stack(regrouped, dim=-1).repeat(t, 1))

        pos_ids = torch.cat(per_entry_ids, dim=0)
        largest_side = grid_thw[:, 1:].max()
        embedding_table = self.rotary_pos_emb(largest_side)
        return embedding_table[pos_ids].flatten(1)

    def get_window_index(self, grid_thw):
        """Compute the window-major ordering of merged units.

        Returns:
            A tuple ``(window_index, cu_window_seqlens)``: a 1-D tensor with
            the merged-unit indices listed window by window, and a Python
            list of cumulative per-window lengths (in patches) starting at 0.
            The list may contain repeated values for windows that consist
            only of padding.
        """
        merge = self.spatial_merge_size
        # Window edge length measured in merged units.
        win = self.window_size // merge // self.patch_size

        index_chunks: list = []
        cu_window_seqlens: list = [0]
        unit_offset = 0

        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h = grid_h // merge
            llm_grid_w = grid_w // merge
            num_units = grid_t * llm_grid_h * llm_grid_w

            unit_ids = torch.arange(num_units).reshape(grid_t, llm_grid_h, llm_grid_w)

            # Pad bottom/right with the sentinel -100 so that both sides
            # become a multiple of the window edge.
            pad_h = win - llm_grid_h % win
            pad_w = win - llm_grid_w % win
            num_windows_h = (llm_grid_h + pad_h) // win
            num_windows_w = (llm_grid_w + pad_w) // win
            padded = F.pad(unit_ids, (0, pad_w, 0, pad_h), "constant", -100)

            # (t, windows, win, win): one win x win tile per window.
            windows = (
                padded.reshape(grid_t, num_windows_h, win, num_windows_w, win)
                .permute(0, 1, 3, 2, 4)
                .reshape(grid_t, num_windows_h * num_windows_w, win, win)
            )

            # Count the real (non-sentinel) units per window.
            seqlens = (windows != -100).sum([2, 3]).reshape(-1)
            flat = windows.reshape(-1)
            index_chunks.append(flat[flat != -100] + unit_offset)

            cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            unit_offset += num_units.item()

        window_index = torch.cat(index_chunks, dim=0)
        return window_index, cu_window_seqlens

    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
        """Encode patches and return the merged features in input order.

        Args:
            hidden_states (`torch.Tensor`):
                Patch input that is handed to ``self.patch_embed``.
            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
                Temporal, height and width grid size of each image/video.

        Returns:
            `torch.Tensor`: output of ``self.merger`` re-indexed with the
            inverse of the window permutation.
        """
        hidden_states = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)

        # Select dtype based on the following factors:
        #  - FA2 requires that cu_seqlens_q must have dtype int32
        #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
        # See https://github.com/huggingface/transformers/pull/34852 for more information
        seqlens_dtype = grid_thw.dtype if torch.jit.is_tracing() else torch.int32

        # Repeated boundaries (left by all-padding windows) are collapsed.
        cu_window_seqlens = torch.unique_consecutive(
            torch.tensor(cu_window_seqlens, device=hidden_states.device, dtype=seqlens_dtype)
        )

        seq_len, _ = hidden_states.size()
        num_units = seq_len // self.spatial_merge_unit

        def _to_window_order(tensor):
            # Permute whole merge units, then flatten back to one row per patch.
            grouped = tensor.reshape(num_units, self.spatial_merge_unit, -1)
            return grouped[window_index, :, :].reshape(seq_len, -1)

        hidden_states = _to_window_order(hidden_states)
        rotary_pos_emb = _to_window_order(rotary_pos_emb)

        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # One sequence of h * w patches per temporal slice of every entry.
        patches_per_frame = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
        cu_seqlens = F.pad(patches_per_frame.cumsum(dim=0, dtype=seqlens_dtype), (1, 0), value=0)

        for layer_num, blk in enumerate(self.blocks):
            cu_seqlens_now = cu_seqlens if layer_num in self.fullatt_block_indexes else cu_window_seqlens
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    blk.__call__, hidden_states, cu_seqlens_now, None, position_embeddings
                )
            else:
                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings)

        hidden_states = self.merger(hidden_states)

        # argsort of the permutation is its inverse: undo the window ordering.
        reverse_indices = torch.argsort(window_index)
        return hidden_states[reverse_indices, :]



In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Bare expression so the notebook echoes which device was selected.
device

### Video info

In [None]:
# Using OpenCV
import cv2

# These paths are on Sockeye
video_path_1 = "/scratch/st-jzhu71-1/ewong25/my_jupyter/How2Sign_Clips/val/bIUmw2DVW7Q_9-3-rgb_front.mp4"
video_path_2 = "/scratch/st-jzhu71-1/ewong25/my_jupyter/How2Sign_Clips/val/a4Nxq0QV_WA_4-5-rgb_front.mp4"
video_path_3 = "/scratch/st-jzhu71-1/ewong25/my_jupyter/How2Sign_Clips/val/a5yNwUSiYpA_11-3-rgb_front.mp4"

cap = cv2.VideoCapture(video_path_1)
try:
    # Fail with a clear message instead of a ZeroDivisionError further down
    # (an unopened capture reports 0 for every property).
    if not cap.isOpened():
        raise OSError(f"Could not open video: {video_path_1}")

    # Frame count and frame rate as reported by the capture backend.
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
finally:
    # Always release the underlying file handle / decoder.
    cap.release()

# Duration in seconds; undefined (NaN) when the backend reports no frame rate.
duration = frame_count / fps if fps > 0 else float("nan")


# Example usage
print(f"Frame count: {frame_count}")
print(f"fps: {fps}")
print(f"Duration: {duration} seconds")

In [None]:
# Using av since it was easier for me
import av
import numpy as np


# Decode every frame of the first video stream of video_path_1.
# The context manager closes the container even if decoding raises part-way.
with av.open(video_path_1) as container:
    # First stream whose type is "video" (raises StopIteration if there is none).
    video_stream = next(stream for stream in container.streams if stream.type == "video")

    # frame.to_image() converts PyAV's own frame object into an image object
    # (presumably a PIL image - the next cell reads its `.size`).
    frames_list = [frame.to_image() for frame in container.decode(video_stream)]

In [None]:
# frames_list holds one `frame.to_image()` result per decoded frame of the
# video; show the size of the first frame as the cell output.
frames_list[0].size

### Qwen video pre-processing

In [None]:
# QWen video preprocess

import numpy as np
from qwen_vl_utils import process_vision_info
video_path_1 = "/scratch/st-jzhu71-1/ewong25/my_jupyter/How2Sign_Clips/val/bIUmw2DVW7Q_9-3-rgb_front.mp4"
video_path_2 = "/scratch/st-jzhu71-1/ewong25/my_jupyter/How2Sign_Clips/val/a4Nxq0QV_WA_4-5-rgb_front.mp4"
video_path_3 = "/scratch/st-jzhu71-1/ewong25/my_jupyter/How2Sign_Clips/val/a5yNwUSiYpA_11-3-rgb_front.mp4"

# One "user" message per video (videos 1 and 3), all sharing the same
# sampling settings. The batch size is not given explicitly; it simply
# follows from the number of messages.
messages = [
    {
        "role": "user",
        "content": [
            {
                "video": path,
                "min_pixels": 4096,  # 64 * 64
                "max_pixels": 16384,  # 128 * 128
                "fps": 4,
            }
        ],
    }
    for path in (video_path_1, video_path_3)
]

# The message list is wrapped in an outer list before being handed over.
image_inputs, video_inputs, video_kwargs = process_vision_info(
    [messages],
    return_video_kwargs=True,
)

In [None]:
# video_kwargs is the third value returned by process_vision_info(...,
# return_video_kwargs=True); it contains the fps after pre-processing.
video_kwargs

In [None]:
# Shape of the first pre-processed video returned by process_vision_info.
video_inputs[0].shape

#### Helper functions for data inspection

In [None]:
# Function to inspect the image/frames after preprocessing

import torchvision.transforms as T

# Rescale pixel values from the 0-255 range to 0-1 before the PIL conversion.
adjusted_video_inputs = video_inputs[0] / 255.0

to_pil = T.ToPILImage()

# Convert each frame (one entry along the first dimension) to a PIL image.
processed_images = [to_pil(frame) for frame in adjusted_video_inputs]


In [None]:
# function to display entire frame list
from PIL import Image

def display_images_side_by_side(image_list):
    """Paste the given images left to right onto one new RGB canvas.

    The canvas is as wide as all images together and as tall as the tallest
    image; every image is pasted with its top edge at y = 0, so shorter
    images leave the area below them at the canvas' initial fill.

    Args:
        image_list: Iterable of image objects exposing ``width`` and
            ``height`` and accepted by ``Image.paste``.

    Returns:
        The newly created combined image (nothing is shown on screen here).

    Raises:
        ValueError: If ``image_list`` contains no images.
    """
    # Materialise once: the images are traversed three times below, which
    # would silently exhaust a generator after the first pass.
    images = list(image_list)
    if not images:
        raise ValueError("image_list must contain at least one image")

    # Canvas dimensions: summed widths, tallest height.
    total_width = sum(img.width for img in images)
    max_height = max(img.height for img in images)

    result = Image.new('RGB', (total_width, max_height))

    # Paste each image directly to the right of the previous one.
    x_offset = 0
    for img in images:
        result.paste(img, (x_offset, 0))
        x_offset += img.width

    return result

# Stitch the pre-processed frames (processed_images, a list of PIL images
# built in the previous cell) into a single wide image.
combined_image = display_images_side_by_side(processed_images)

# Display the combined image.
# NOTE(review): .show() presumably opens an external viewer rather than
# rendering inline in the notebook - confirm this works on the cluster.
combined_image.show()

#### Qwen vision processor

In [None]:
from transformers import Qwen2VLProcessor

# Load the processor that ships with the Qwen2.5-VL-3B-Instruct checkpoint.
processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")

# Only the video branch is of interest here, so the text prompt is empty.
inputs = processor(
    text="",
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs

In [None]:
# Shape of the video pixel tensor produced by the processor; this is the
# tensor later passed to the vision encoder as its first argument.
inputs["pixel_values_videos"].shape

### Run Vision Encoder

In [None]:
from transformers import AutoModel

from transformers import AutoConfig
# Vision-tower configuration taken from the Qwen2.5-VL-3B-Instruct checkpoint.
# `spatial_merge_size=4` can be passed here as well to experiment with a
# larger merge size.
# NOTE(review): `torch_dtype` is only stored on the config here; presumably a
# model built directly from the config still uses the default dtype - confirm.
v_config = Qwen2_5_VLVisionConfig.from_pretrained(
    "Qwen/Qwen2.5-VL-3B-Instruct",
    torch_dtype="bfloat16",
)

# The encoder is constructed from the config alone (no `from_pretrained`
# call on the model class), i.e. no checkpoint weights are loaded here.
model = Qwen2_5_VisionTransformerPretrainedModel(v_config)

In [None]:
# Move the encoder and the processor output to the selected device.
# Rebinding keeps the code correct even if `.to()` returns a new object.
model = model.to(device)
inputs = inputs.to(device)

# Profile a single forward pass of the vision encoder (CPU + CUDA activity,
# tensor shapes, memory usage and Python stacks are recorded).
with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    schedule=torch.profiler.schedule(wait=0, warmup=0, active=6, repeat=1),
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
) as prof:
    # Call the module itself rather than `.forward` so that any registered
    # hooks run as well.
    v_embedding = model(inputs['pixel_values_videos'], inputs['video_grid_thw'])

    # Tell the profiler schedule that one step has finished.
    prof.step()

In [None]:
# Shape of the encoder output for the profiled forward pass.
v_embedding.shape

# DIMENSIONS recorded from earlier runs (encoder input -> encoder output):

# 1 video: INPUT: torch.Size([128, 1176]), OUTPUT: torch.Size([32, 2048])
# 2 videos: INPUT: torch.Size([576, 1176]), OUTPUT: torch.Size([144, 2048])
# 2 videos, merge size 4: INPUT: torch.Size([576, 1176]), OUTPUT: torch.Size([36, 2048])

In [None]:
import matplotlib
# Write the memory timeline recorded by the profiler for device "cuda:0"
# to an HTML report in the working directory.
prof.export_memory_timeline(
    "1_video_4_fps_memory_check.html",
    device="cuda:0",
)

# Forward & Backward run testing

In [None]:
# Experiment class
import math

import numpy as np
import torch
import lightning as pl
import transformers
from transformers import AutoTokenizer, AutoModel

import contrastive_encoders.encoders as encoders
import contrastive_encoders.losses as losses

import bitsandbytes as bnb


class VideoTextExp(pl.LightningModule):
    """Lightning experiment pairing a video encoder with a text encoder.

    The text encoder is loaded from ``jinaai/jina-embeddings-v3``; the video
    encoder is created by ``encoders.initialize_vision_encoder`` from
    ``video_encoder_cfg``. Training and validation both compute
    ``losses.ContrastiveSigmoid`` on the two feature sets.

    Args:
        video_encoder_cfg: Configuration forwarded to
            ``encoders.initialize_vision_encoder``.
        text_encoder_cfg: Mapping read in ``forward`` for the keys ``"task"``
            and ``"out_hidden_size"``.
        sample_rate: Stored as a hyperparameter only; not read in this class.
        initial_lr: Learning rate for the optimizer.
        weight_decay: Weight decay for the optimizer.
        num_warmup_steps: Warm-up steps of the cosine schedule.
        hard_negatives: Stored on the module; not read in this class.
        tokenizer: Optional name/path passed to ``AutoTokenizer.from_pretrained``.
        processor: Optional name/path passed to ``AutoProcessor.from_pretrained``.
        text: Stored on the module; not read in this class.
    """

    def __init__(
        self,
        video_encoder_cfg,
        text_encoder_cfg,
        sample_rate: int = 16000,
        initial_lr: float = 1e-4,
        weight_decay: float = 1e-4,
        num_warmup_steps: int = 0,
        hard_negatives: bool = False,
        tokenizer=None,
        processor=None,
        text=False,
    ):
        super().__init__()

        # Makes every constructor argument available under self.hparams.
        self.save_hyperparameters()

        self.text_encoder = AutoModel.from_pretrained(
            "jinaai/jina-embeddings-v3",
            trust_remote_code=True,
        )
        print("text_encoder initiated")

        self.video_encoder = encoders.initialize_vision_encoder(self.hparams.video_encoder_cfg)
        print("video_encoder initiated")

        # NOTE(review): stored without being called; it is later invoked as
        # self.loss(video_features, text_features, t_prime, b) - confirm that
        # ContrastiveSigmoid is a function with that signature, not a class.
        self.loss = losses.ContrastiveSigmoid

        # t' and b for the loss function, should be set in the yaml file?
        self.t_prime = torch.tensor(math.log(10))
        self.b = torch.tensor(-10.0)

        self.hard_negatives = hard_negatives
        self.text = text
        # Per-batch validation losses, reduced in on_validation_epoch_end.
        self.validation_step_outputs = []

        # Always define both attributes so later access never raises
        # AttributeError when the optional arguments were omitted.
        self.tokenizer = None
        self.processor = None
        if tokenizer is not None:
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        if processor is not None:
            # Imported here because the surrounding imports only bring in
            # AutoTokenizer and AutoModel.
            from transformers import AutoProcessor

            self.processor = AutoProcessor.from_pretrained(processor)

    def configure_optimizers(self):
        """Return an 8-bit Adam optimizer with a per-step cosine schedule."""
        model_params = [
            {"params": self.video_encoder.parameters()},
            {"params": self.text_encoder.parameters()},
        ]

        # Using 8-bit adam
        optimizer = bnb.optim.Adam8bit(
            model_params,
            lr=self.hparams.initial_lr,
            weight_decay=self.hparams.weight_decay,
        )

        # A non-positive max_steps means "not set" on the trainer; fall back
        # to the trainer's own estimate so the schedule gets a usable length.
        max_steps = self.trainer.max_steps
        if max_steps is None or max_steps <= 0:
            max_steps = self.trainer.estimated_stepping_batches

        scheduler = transformers.get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.hparams.num_warmup_steps,
            num_training_steps=max_steps,
        )

        return (
            [optimizer],
            [{"scheduler": scheduler, "interval": "step"}],
        )

    def forward(self, video_input, text_input):
        """Encode one batch and return ``(video_features, text_features)``.

        Args:
            video_input: Mapping with the keys ``'pixel_values_videos'`` and
                ``'video_grid_thw'`` (see ``encode_video``).
            text_input: Passed unchanged to the text encoder's ``encode``
                (presumably a list of raw strings - confirm against caller).
        """
        video_features = self.encode_video(video_input)

        text_features = self.text_encoder.encode(
            text_input,
            task=self.hparams.text_encoder_cfg["task"],
            truncate_dim=self.hparams.text_encoder_cfg["out_hidden_size"],
        )
        # Create the text tensor on the same device as the video features so
        # the loss does not mix CPU and GPU tensors.
        text_features = torch.tensor(text_features, device=video_features.device)

        return video_features, text_features

    def encode_video(self, video_input):
        """Run the video encoder on the pixel values and grid sizes."""
        pixel_values = video_input['pixel_values_videos']
        grid_thw = video_input['video_grid_thw']

        video_features = self.video_encoder(pixel_values, grid_thw)

        return video_features

    def encode_text(self, text_input):
        """Call the text encoder with ``text_input`` unpacked as keywords."""
        text_features = self.text_encoder(**text_input)
        return text_features

    def _batch_loss(self, batch):
        """Compute the contrastive loss for one ``(video, text)`` batch."""
        video_input, text_input = batch
        video_features, text_features = self(video_input, text_input)
        return self.loss(video_features, text_features, self.t_prime, self.b)

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``"loss"``) and return the training loss."""
        loss = self._batch_loss(batch)
        self.log("loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and keep it for the epoch average."""
        loss = self._batch_loss(batch)
        self.validation_step_outputs.append(loss)

    def on_validation_epoch_end(self):
        """Log the mean validation loss of the epoch as ``"val_loss"``.

        Replaces the former ``validation_epoch_end(outputs)`` hook, which
        recent Lightning releases no longer support. The stored entries are
        plain loss tensors, so they are stacked directly. The value is logged
        on every rank because ``sync_dist=True`` reduces across processes,
        and the buffer is cleared so it does not grow across epochs.
        """
        if not self.validation_step_outputs:
            return

        avg_loss = torch.stack(self.validation_step_outputs).mean()
        self.log("val_loss", avg_loss, sync_dist=True)
        self.validation_step_outputs.clear()

    

In [None]:
# Initializing model
import yaml
from transformers import AutoConfig

# Read the experiment configuration from the YAML file.
with open("config/base.yaml", "r") as file:
    config = yaml.safe_load(file)

# Both encoder configs live under model -> init_args.
init_args = config["model"]["init_args"]
text_encoder_cfg = init_args["text_encoder_cfg"]
video_encoder_cfg = init_args["video_encoder_cfg"]

# Build the experiment module from the two loaded configs; every other
# constructor argument keeps its default.
model = VideoTextExp(
    video_encoder_cfg=video_encoder_cfg,
    text_encoder_cfg=text_encoder_cfg,
)

# Print to verify
#print(model.hparams)

In [None]:
# Sample Data preprocessing
# Raw sentences used as the text side of the test batch (plain Python
# strings, not tensors).
example_text = [
    "This is an example sentence.",
    "this short sentence.",
    #"I have an extremely long sentence that I want to test the text-encoder on."
]

import numpy as np
from qwen_vl_utils import process_vision_info
from transformers import Qwen2VLProcessor

video_path_1 = "/scratch/st-jzhu71-1/ewong25/my_jupyter/How2Sign_Clips/val/bIUmw2DVW7Q_9-3-rgb_front.mp4"
video_path_2 = "/scratch/st-jzhu71-1/ewong25/my_jupyter/How2Sign_Clips/val/a4Nxq0QV_WA_4-5-rgb_front.mp4"
video_path_3 = "/scratch/st-jzhu71-1/ewong25/my_jupyter/How2Sign_Clips/val/a5yNwUSiYpA_11-3-rgb_front.mp4"

# A single "user" message carrying video 2. Further videos (e.g.
# video_path_3) can be added as additional messages of the same form.
video_content = {
    "video": video_path_2,
    "min_pixels": 4096,  # 64 * 64
    "max_pixels": 16384,  # 128 * 128
}
messages = [{"role": "user", "content": [video_content]}]


# Video Preprocessing
image_inputs, video_inputs, video_kwargs = process_vision_info(
    [messages],
    return_video_kwargs=True,
)

processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")

# NOTE(review): `mergesize` is passed through verbatim; the intended keyword
# is presumably `merge_size` - confirm whether the processor honours or
# silently ignores this spelling.
v_inputs = processor(
    text="",
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
    mergesize=4,
)
v_inputs

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model and the processed video batch to the selected device.
# `example_text` is a plain Python list of strings and has no `.to()`
# method; it is handed to the model as-is.
model = model.to(device)
v_inputs = v_inputs.to(device)

# Profile one forward + backward pass (CPU + CUDA activity, tensor shapes,
# memory usage and Python stacks are recorded).
with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    schedule=torch.profiler.schedule(wait=0, warmup=0, active=6, repeat=1),
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
) as prof:
    # Call the module itself rather than `.forward` so that any registered
    # hooks run as well.
    video_features, text_features = model(v_inputs, example_text)
    loss = model.loss(video_features, text_features, model.t_prime, model.b)
    loss.backward()