In [1]:
import os
import sys

import torch
import numpy as np
import pandas as pd

import shared.utils as su

**Load data**

In [2]:
csv_path = "/scratch/shared/beegfs/piyush/datasets/SimCSE-NLI/nli-27k+ego4d-3k.csv"
df_train = pd.read_csv(csv_path)
df_train.shape

(30000, 3)

In [3]:
df_train.iloc[0].to_dict()

{'sent0': 'Male street vendor selling an ear of corn.',
 'sent1': 'The street vendor is outside.',
 'hard_neg': 'The street vendor is under arrest.'}

In [4]:
from notebooks.eval_care import load_data
df_valid = load_data(dataset='ssv2', split='validation')
df_valid.shape

Number of rows:  1430
Sample row: 
{
    "id": 69703,
    "label": "moving pen up",
    "template": "Moving [something] up",
    "placeholders": "['pen']",
    "target": 114,
    "chiral_label": 0.0,
    "chiral_triplet_id": "3f20f09b",
    "noun": "['something']",
    "text_id": "3f20f09b_0.0",
    "video_path": "/scratch/shared/beegfs/piyush/datasets/SSv2/20bn-something-something-v2/69703.webm"
}


(1430, 10)

**Load model**

In [5]:
from models.modeling_encoders import AutoEncoder

In [6]:
encoder = AutoEncoder.from_pretrained(
    "/work/piyush/pretrained_checkpoints/CaRe-7B-Stage-1/",
    device_map='auto',
    dtype=torch.bfloat16,
    attn_implementation='flash_attention_2',
)
encoder

`torch_dtype` is deprecated! Use `dtype` instead!


Loading EncoderForCaRe from /work/piyush/pretrained_checkpoints/CaRe-7B-Stage-1/


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release.


<models.modeling_encoders.EncoderForCaRe at 0x7f065febf2d0>

In [7]:
su.misc.num_params(encoder.model)

::: Number of total parameters in Qwen2VLForConditionalGeneration: 8291.376M


In [8]:
from utils.video import read_frames_decord

video_path = df_valid.video_path[0]
n_frames = 16
video = read_frames_decord(video_path, n_frames, width=480, height=270)
video.shape

torch.Size([16, 3, 270, 480])

In [21]:
with torch.no_grad():
    zv = encoder.encode_vision(pixel_values=video.unsqueeze(0))
    zt = encoder.encode_text(['A dog'])
zv.shape, zt.shape

(torch.Size([1, 3584]), torch.Size([1, 3584]))

In [16]:
# with torch.no_grad():
#     zv = encoder.encode_vision(pixel_values=torch.stack([video] * 12)).cpu()
# zv.shape

In [17]:
from notebooks.eval_care_retrieval import compute_metrics


def validation_metrics(encoder, df):

    # Compute video metrics
    video_paths = df.video_path.unique()
    video_ids = df.id.unique()
    video_feat = {}
    j = 0
    for video_path in su.log.tqdm_iterator(video_paths, desc='Computing video features'):
        video = read_frames_decord(video_path, n_frames, width=480, height=270)
        with torch.no_grad():
            zv = encoder.encode_vision(pixel_values=video.unsqueeze(0)).cpu().squeeze(0)
        zv = torch.nn.functional.normalize(zv, dim=-1)
        video_feat[video_ids[j]] = zv.cpu().float()
        j += 1

    # Compute text features
    text_ids = df['text_id'].unique()
    texts_feat = {}
    for text_id in su.log.tqdm_iterator(text_ids, desc='Computing text features'):
        text = df[df.text_id == text_id].template.unique()[0]
        with torch.no_grad():
            zt = encoder.encode_text([text]).cpu().squeeze(0)
        zt = torch.nn.functional.normalize(zt, dim=-1)
        texts_feat[text_id] = zt.cpu().float()

    metrics = compute_metrics(df, video_feat, texts_feat, show_metrics=False)
    return metrics


metrics = validation_metrics(encoder, df_valid)

Computing video features:   0%|          | 0/1430 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [23]:
su.misc.num_trainable_params(encoder.model)

::: Number of trainable parameters in PeftModelForSeq2SeqLM: 2.523 M


In [22]:
from peft import LoraConfig, get_peft_model

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=8,
    bias="none",
    target_modules=["q_proj", "v_proj"],
    task_type="SEQ_2_SEQ_LM",
)

encoder.model = get_peft_model(encoder.model, peft_config)
encoder.model.print_trainable_parameters()

trainable params: 2,523,136 || all params: 8,293,898,752 || trainable%: 0.0304


In [21]:
su.misc.num_trainable_params(encoder.model)

::: Number of trainable parameters in PeftModelForCausalLM: 2.523 M


In [24]:
metrics = validation_metrics(encoder, df_valid)

Computing video features:   0%|          | 0/1430 [00:00<?, ?it/s]

KeyboardInterrupt: 

#### Failed attempts at speeding up video feature computation

In [12]:
import torch
from torch.utils.data import Dataset, DataLoader

class VideoDataset(Dataset):
    def __init__(self, video_paths, video_ids, n_frames, width=480, height=270):
        self.video_paths = video_paths
        self.video_ids = video_ids
        self.n_frames = n_frames
        self.width = width
        self.height = height
    
    def __len__(self):
        return len(self.video_paths)
    
    def __getitem__(self, idx):
        video_path = self.video_paths[idx]
        video_id = self.video_ids[idx]
        
        # Load video frames
        video = read_frames_decord(video_path, self.n_frames, 
                                   width=self.width, height=self.height)
        
        return {
            'video': video,
            'video_id': video_id,
            'video_path': video_path
        }


class TextDataset(Dataset):
    def __init__(self, df):
        self.text_ids = df['text_id'].unique()
        self.texts = [df[df.text_id == tid].template.unique()[0] 
                      for tid in self.text_ids]
    
    def __len__(self):
        return len(self.text_ids)
    
    def __getitem__(self, idx):
        return {
            'text_id': self.text_ids[idx],
            'text': self.texts[idx]
        }


def validation_metrics_faster(encoder, df, n_frames=16, batch_size=4, num_workers=4):
    """
    Compute validation metrics with parallelized data loading.
    
    Args:
        encoder: The encoder model
        df: DataFrame containing video_path, id, text_id, template columns
        n_frames: Number of frames to extract from each video
        batch_size: Batch size for DataLoader (for loading, not model inference)
        num_workers: Number of workers for parallel data loading
    
    Returns:
        metrics: Dictionary of computed metrics
    """
    
    # ==================== Compute Video Features ====================
    video_paths = df.video_path.unique()
    video_ids = df.id.unique()
    
    video_dataset = VideoDataset(video_paths, video_ids, n_frames, 
                                 width=480, height=270)
    video_loader = DataLoader(
        video_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,  # Faster transfer to GPU
        prefetch_factor=2  # Prefetch 2 batches per worker
    )
    
    video_feat = {}
    
    for batch in su.log.tqdm_iterator(video_loader, desc='Computing video features'):
        videos = batch['video']  # Shape: (batch_size, C, T, H, W)
        batch_video_ids = batch['video_id']
        
        # Process each video in the batch individually to avoid OOM
        for i in range(len(videos)):
            video = videos[i].unsqueeze(0)  # Add batch dimension
            video_id = batch_video_ids[i]
            
            # Move to GPU and compute features
            if torch.cuda.is_available():
                video = video.cuda()
            
            with torch.no_grad():
                zv = encoder.encode_vision(pixel_values=video).cpu().squeeze(0)
            
            zv = torch.nn.functional.normalize(zv, dim=-1)
            video_feat[video_id] = zv.float()
            
            # Clean up
            del video
        
        # Clean up batch
        del videos
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    
    # ==================== Compute Text Features ====================
    text_dataset = TextDataset(df)
    text_loader = DataLoader(
        text_dataset,
        batch_size=32,  # Can use larger batch for text
        shuffle=False,
        num_workers=2
    )
    
    texts_feat = {}
    
    for batch in su.log.tqdm_iterator(text_loader, desc='Computing text features'):
        batch_texts = batch['text']
        batch_text_ids = batch['text_id']
        
        with torch.no_grad():
            zt_batch = encoder.encode_text(list(batch_texts)).cpu()
        
        zt_batch = torch.nn.functional.normalize(zt_batch, dim=-1)
        
        # Store individually
        for text_id, zt in zip(batch_text_ids, zt_batch):
            texts_feat[text_id] = zt.float()
    
    # ==================== Compute Metrics ====================
    metrics = compute_metrics(df, video_feat, texts_feat, show_metrics=False)
    
    return metrics

metrics = validation_metrics_faster(encoder, df_valid, batch_size=8, num_workers=8)

Computing video features:   0%|          | 0/179 [00:00<?, ?it/s]

KeyboardInterrupt: 