# Fine-tune Video-LLaVA on a custom temporal dataset

In this notebook, we fine-tune the [Video-LLaVa](https://huggingface.co/docs/transformers/main/en/model_doc/video_llava) model on the [a-temporal-upgrade](https://huggingface.co/datasets/jwnt4/a-temporal-upgrade/viewer/moment_retrieval) dataset, which covers two video-related tasks: action ordering and moment retrieval. Here, however, we train only on moment retrieval.

Video-LLaVa is an open-source multimodal model that can accept both images and videos as input, in an interleaved manner. Its architecture is very similar to [LLaVa](https://huggingface.co/docs/transformers/main/en/model_doc/llava); the main difference is that Video-LLaVa leverages a new universal visual encoder to seamlessly handle both visual modalities. As we'll see, fine-tuning these models is very similar, since their APIs are mostly the same.

* Video-LLaVa [documentation](https://huggingface.co/docs/transformers/main/en/model_doc/video_llava)
* Video-LLaVa [checkpoint on the hub](https://huggingface.co/LanguageBind/Video-LLaVA-7B-hf)

The goal of this notebook is to train LoRA adapter layers on top of the model to enhance its temporal awareness.


The whole training consists of 3 stages.

**Stage 1 (this notebook)**
- Training dataset: 640 samples
- Training loss: Cross Entropy Loss
- Validation: IoU
- Gradient accumulation step: 2
- Val size: 24
- Val steps: 160
- Early stopping: IoU
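The Stage 1 validation metric is temporal IoU: the length of the overlap between the predicted and ground-truth time intervals divided by the length of their union. The notebook's `mr_iou_score(predictions, frame_info, labels)` comes from the project package and its internals are not shown, so the following is only a generic sketch under our own assumptions: the model answers with a `"mm:ss, mm:ss"` string, the label is a `['mm:ss', 'mm:ss']` pair, and a malformed answer scores 0. All helper names are hypothetical.

```python
from typing import Optional, Tuple


def mmss_to_seconds(ts: str) -> int:
    """Convert a 'mm:ss' timestamp string to whole seconds."""
    minutes, seconds = ts.strip().split(":")
    return int(minutes) * 60 + int(seconds)


def parse_prediction(text: str) -> Optional[Tuple[int, int]]:
    """Parse an answer like '00:07, 00:29' into (start, end) seconds, or None if malformed."""
    parts = text.split(",")
    if len(parts) != 2:
        return None
    try:
        start, end = mmss_to_seconds(parts[0]), mmss_to_seconds(parts[1])
    except ValueError:
        return None
    return min(start, end), max(start, end)


def temporal_iou(pred: Tuple[int, int], gt: Tuple[int, int]) -> float:
    """Overlap length divided by union length of two [start, end] intervals."""
    inter = max(0, min(pred[1], gt[1]) - max(pred[0], gt[0]))
    union = (pred[1] - pred[0]) + (gt[1] - gt[0]) - inter
    if union == 0:
        # both intervals are single points: count an exact match as full overlap
        return 1.0 if pred == gt else 0.0
    return inter / union


def iou_for_answer(prediction: str, label: list) -> float:
    """Score one model answer against a ['mm:ss', 'mm:ss'] label; malformed answers score 0."""
    pred = parse_prediction(prediction)
    if pred is None:
        return 0.0
    gt = (mmss_to_seconds(label[0]), mmss_to_seconds(label[1]))
    return temporal_iou(pred, gt)


# overlap 00:15-00:29 (14 s), union 00:07-00:29 (22 s)
print(round(iou_for_answer("00:07, 00:29", ["00:15", "00:29"]), 3))  # 0.636
```

Using the true union (sum of lengths minus overlap) rather than the span from the earliest start to the latest end keeps disjoint intervals at an IoU of exactly 0.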



In [1]:
from project.dataset.collate import DataCollatorWithPadding
from project.dataset.prepare import MomentRetrievalDataset
from project.trainer.lightning import VideoLlavaModelPLModule
from project.trainer.peft import find_all_linear_names
from project.dataset.utils import view_sample_with_video
from project.trainer.metrics import mr_iou_score

from transformers import (
    VideoLlavaProcessor,
    BitsAndBytesConfig,
    VideoLlavaForConditionalGeneration,
    LlamaForCausalLM
)
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from datasets import load_from_disk, load_dataset
from torch.utils.data import DataLoader
import torch
from dataclasses import dataclass
from lightning import Trainer
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
import argparse
import logging
import os
import sys
from datetime import datetime
from lightning.pytorch.strategies import DeepSpeedStrategy

[2024-12-19 14:19:43,434] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to mps (auto detect)


W1219 14:19:43.831000 13398 torch/distributed/elastic/multiprocessing/redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.


'NoneType' object has no attribute 'cadam32bit_grad_fp32'


  warn("The installed version of bitsandbytes was compiled without GPU support. "


In [2]:
@dataclass
class Config:
    lora_r: int = 8
    lora_alpha: int = 16
    batch_size: int = 2
    max_epoch: int = 2
    val_check_interval: float = 0.25
    learning_rate: float = 2e-5
    dataset_dir: str = "datasets/processed"
    num_frames: int = 14
    num_worker: int = 2
    hub_repo: str = "jwnt4/finetune-videollava-qlora"
    accumulate_grad_batches: int = 4
    limit_val_batches: float = 24

args = Config()  # instance holding the default settings above

In [3]:
torch.set_float32_matmul_precision('high')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
stream_handler = logging.StreamHandler(sys.stderr)

if not os.path.isdir("logs"):
    os.makedirs("logs")

log_file = f"logs/{str(datetime.now()).replace(' ', '_')}.log"
file_handler = logging.FileHandler(log_file)

log_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S")
stream_handler.setFormatter(log_formatter)

logger.addHandler(stream_handler)
logger.addHandler(file_handler)

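The processor combines a fast (Rust-based) tokenizer with an image/video preprocessor. Each frame is resized and cropped to 224x224 pixels, and `patch_size` is the ViT patch size (the stride of the patch-embedding convolution), which determines how many visual tokens each frame produces.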
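As a quick sanity check on sequence length, the arithmetic below counts the patches per 224x224 frame with a 14-pixel patch size. It is a sketch of the grid arithmetic only: whether the model also keeps one CLS token per frame depends on its feature-selection logic, which we do not assume here (the `Embedding(257, 1024)` position table in the model printout at the end of this notebook is 256 patches plus 1 CLS). Either way, 14 frames contribute several thousand visual tokens, which is why the tokenized sequences below are roughly 4k tokens long.

```python
def patches_per_frame(image_size: int = 224, patch_size: int = 14) -> int:
    """Number of non-overlapping patch_size x patch_size tiles in a square frame."""
    side = image_size // patch_size  # 224 // 14 = 16 patches per side
    return side * side


num_frames = 14  # Config.num_frames in this notebook
per_frame = patches_per_frame()
print(per_frame)               # 256 patches per frame
print(per_frame * num_frames)  # 3584 patch tokens for a 14-frame clip
```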

In [4]:
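processor = VideoLlavaProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B-hf")
processor.patch_size = 14  # ViT patch size: a 224x224 frame becomes a 16x16 grid of patches
processor.vision_feature_select_strategy = "default"
# mark the "padding a fast tokenizer" warning as already shown, which silences it
processor.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True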

We use a custom dataset preparer class, `MomentRetrievalDataset`, to hide the complexity of loading and transforming the dataset. First, let's look at the raw dataset.

In [5]:
raw_dataset = load_dataset("jwnt4/a-temporal-upgrade", "moment_retrieval")
raw_sample = raw_dataset['train'][5]
for k,v in raw_sample.items():
    print(f"{k}:\n{v}\n")

video_id:
v_2Sev8z4P7pE

duration:
45.279998779296875

prompt_frame:
Your task is to determine the frame range that best represents an action in the video. Provide your answer as two numbers separated by a comma, where the first number is the first frame correlated to the action and the second number is the last frame correlated to the action. Do not provide any other explanation in your response.
Here is an example: Suppose the video has frames numbered from 1 to N (N is the number of frames sampled). If the action in question is "a man sings on the street" and the frames that has the most similarities with this action are 5, 6, and 7 your answer should be: "5, 7" (this is an example).
Number of frames sampled in this video: <num_frames>
Here is the video context: A man rake dead leaves in a backyard.
Here is the action in question: <action>
Here is the question: What is the frame range (start, end) in the video that best represents the action asked?

prompt_timestamp:
Your task is to

Each sample has:
- video_id: ID of the video
- duration: video duration in seconds
- prompt_frame: frame-style prompt (asks the model for the frame range of the action)
- prompt_timestamp: timestamp-style prompt (asks the model for the timestamp range of the action)
- action: the action the model has to locate in the video
- answer: [start, end] time of the action
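The raw `prompt_frame` text contains placeholders (`<num_frames>`, `<action>`) that have to be filled in per sample. `MomentRetrievalDataset` does this internally in a way not shown here; the sketch below simply assumes plain string substitution, and `fill_prompt` is a hypothetical name.

```python
def fill_prompt(template: str, num_frames: int, action: str) -> str:
    """Hypothetical helper: substitute the placeholders found in the raw prompt text."""
    return template.replace("<num_frames>", str(num_frames)).replace("<action>", action)


template = (
    "Number of frames sampled in this video: <num_frames>\n"
    "Here is the action in question: <action>"
)
print(fill_prompt(template, 14, "The cat closes its eyes as it grooms"))
```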

Next, we transform the dataset into the format the model expects.

Each training batch contains:
- input_ids: tokenized inputs, shape (Batch, Token)
- attention_mask: 1s and 0s marking which tokens the model should attend to, shape (Batch, Token)
- labels: a copy of input_ids with pad tokens replaced by -100 (the ignore index of cross-entropy), shape (Batch, Token)
- pixel_values_videos: pixel values of the videos, shape (Batch, Frame, Channel, Height, Width)
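The `labels` rule for training batches (pad positions set to -100 so cross-entropy ignores them) can be sketched in plain Python. The real `DataCollatorWithPadding` works on tensors and its code is not shown; the pad id of 0 below is an illustrative assumption.

```python
from typing import List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def make_labels(input_ids: List[List[int]], pad_token_id: int) -> List[List[int]]:
    """Copy input_ids, replacing every pad token with IGNORE_INDEX so the loss skips it."""
    return [[IGNORE_INDEX if tok == pad_token_id else tok for tok in row] for row in input_ids]


# two right-padded sequences; pad id 0 is an illustrative assumption
batch = [[1, 319, 1234, 2],
         [1, 319, 0, 0]]
print(make_labels(batch, pad_token_id=0))  # [[1, 319, 1234, 2], [1, 319, -100, -100]]
```

With tensors the same rule is usually written as `labels = input_ids.clone(); labels[labels == pad_token_id] = -100`. If the pad id coincides with a real token (some Llama setups reuse EOS as the pad token), masking where `attention_mask == 0` is the safer criterion.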

Each evaluation or test batch contains:
- input_ids: tokenized inputs, shape (Batch, Token)
- attention_mask: 1s and 0s marking which tokens the model should attend to, shape (Batch, Token)
- labels: the ground-truth answers, shape (Batch, 2), each row a [start, end] pair such as ['00:20', '00:40']
- pixel_values_videos: pixel values of the videos, shape (Batch, Frame, Channel, Height, Width)
- ts_info: frame-to-timestamp information used for evaluation, shape (Batch, Frame); the evaluation code unpacks it as `frame_info`
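`ts_info` ties each sampled frame to a timestamp, which is what the frame-to-timestamp mapping in the prompt shows. The project's exact sampling and rounding are not shown, so this sketch assumes frames spread uniformly over the full duration and rounded to the nearest second; its values can differ by a second from the mapping printed below, and `frame_timestamps` is a hypothetical name.

```python
from typing import List


def frame_timestamps(duration: float, num_frames: int) -> List[str]:
    """Hypothetical mapping: frames spread uniformly from 0 s to `duration` s, rounded to whole seconds."""
    if num_frames < 2:
        raise ValueError("need at least two frames")
    step = duration / (num_frames - 1)
    stamps = []
    for i in range(num_frames):
        seconds = round(i * step)
        stamps.append(f"{seconds // 60:02d}:{seconds % 60:02d}")
    return stamps


for i, ts in enumerate(frame_timestamps(22, 14), start=1):
    print(f"Frame {i} at {ts}")
```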

In [6]:
base_dir = args.dataset_dir.split("/")[0]
processed_dir = args.dataset_dir.split("/")[1]
dp = MomentRetrievalDataset(
    base_dir=base_dir, processed_dir=processed_dir, num_frames=args.num_frames, num_worker=1, processor=processor
)

In [7]:
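dataset_path = f"{args.dataset_dir}/moment_retrieval/timestamp/{args.num_frames}_frames"
try:
    # reuse the processed dataset if a copy already exists on disk
    dataset = load_from_disk(dataset_path)
except FileNotFoundError:
    # otherwise build it from the raw dataset, using timestamp-style prompts
    dataset = dp.prepare_dataset(use_frame=False)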

In [10]:
train_dataset = dataset['train']
eval_dataset = dataset['validation']
train_dataloader = DataLoader(dataset['train'], collate_fn=DataCollatorWithPadding(processor), batch_size=args.batch_size, shuffle=False, num_workers=1)
eval_dataloader = DataLoader(dataset['validation'], collate_fn=DataCollatorWithPadding(processor), batch_size=1, shuffle=False, num_workers=1)


In [14]:
example = next(iter(train_dataloader))
view_sample_with_video({"pixel_values_videos": example[2], "input_ids": example[0]}, processor)

prompt:
USER: Your task is to determine the timestamp range that best represents an action in the video. Use the provided frame-to-timestamp mapping to associate the timestamps with the actual video frames. Find the most similar continuous sequence of timestamp with the action asked.
Provide your answer as two timestamps in the format "mm:ss, mm:ss" (e.g. "00:10, 00:30"), where the first timestamp is the start time of the action and the second timestamp is the end time of the action. Do not provide any other explanation in your response. 
Video duration: 22 seconds
Frames sampled: 14
The frame-to-timestamp mapping for this video:
Frame 1 at 00:00
Frame 2 at 00:02
Frame 3 at 00:03
Frame 4 at 00:05
Frame 5 at 00:07
Frame 6 at 00:08
Frame 7 at 00:10
Frame 8 at 00:12
Frame 9 at 00:13
Frame 10 at 00:15
Frame 11 at 00:17
Frame 12 at 00:18
Frame 13 at 00:20
Frame 14 at 00:22
Video context: A calico cat is seen sitting on a white sheet. 
Action in question: The cat closes its eyes as it grooms

In [17]:
_input_ids, _attention_mask, _pixel_values_videos, _labels = example
print(f"shape of input_ids (Batch, Token): {_input_ids.shape}")
print(f"shape of attention_mask (Batch, Token): {_attention_mask.shape}")
print(f"shape of pixel_values_videos (Batch, Frame, Channel, Height, Width): {_pixel_values_videos.shape}")
print(f"shape of labels (Batch, Token): {_labels.shape}")

shape of input_ids (Batch, Token): torch.Size([2, 3996])
shape of attention_mask (Batch, Token): torch.Size([2, 3996])
shape of pixel_values_videos (Batch, Frame, Channel, Height, Width): torch.Size([2, 14, 3, 224, 224])
shape of labels (Batch, Token): torch.Size([2, 3996])


## Load model

This model has about 7 billion parameters and has already undergone supervised instruction tuning on video-chat data.

We load it with quantization and train only LoRA layers on top of it. LoRA (low-rank adaptation) freezes the existing weights and trains a small set of low-rank adapter matrices added to selected layers of the base model. The quantization comes from bitsandbytes and stores each frozen weight in 4 bits (NF4) instead of 16 bits (bf16) or 32 bits (fp32). This lets us train the model with batch size 2 on a 24 GB GPU.
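A back-of-the-envelope comparison of weight memory for a 7B-parameter model at the precisions mentioned above. It counts weights only, in decimal gigabytes, and ignores activations (large here, with roughly 4k tokens per sample), LoRA gradients and optimizer state, and the small overhead of NF4 block scales (which double quantization shrinks further).

```python
def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Memory for the weights alone, in decimal gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


num_params = 7e9  # the "7B" in Video-LLaVA-7B, rounded
for name, bits in [("fp32", 32), ("bf16", 16), ("nf4", 4)]:
    print(f"{name}: {weight_memory_gb(num_params, bits):.1f} GB")
# fp32: 28.0 GB / bf16: 14.0 GB / nf4: 3.5 GB
```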

In [11]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model = VideoLlavaForConditionalGeneration.from_pretrained(
    "LanguageBind/Video-LLaVA-7B-hf",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
    device_map="auto"
)

model.generation_config.max_new_tokens = 32
model.config.return_dict = True

The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.


After loading the base model, we add the LoRA adapter layers.

In [12]:
lora_config = LoraConfig(
    r=args.lora_r,
    lora_alpha=args.lora_alpha,
    lora_dropout=0.1,
    target_modules=find_all_linear_names(model),
    init_lora_weights="gaussian",
    task_type="CAUSAL_LM"
) 

model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)
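To see what `r=8` and `lora_alpha=16` mean in practice: for each adapted linear layer, LoRA trains a matrix A of shape (r, in_features) and a matrix B of shape (out_features, r), and adds their scaled product to the frozen layer's output. The sketch assumes PEFT's standard scaling of `lora_alpha / r` (not the rsLoRA variant); the helper names are hypothetical. The model printout at the end of this notebook shows the same pattern, with `lora_A` mapping 1024 features down to 8.

```python
def lora_extra_params(in_features: int, out_features: int, r: int) -> int:
    """Trainable parameters LoRA adds to one Linear layer: A is (r x in), B is (out x r)."""
    return r * in_features + out_features * r


def lora_scaling(lora_alpha: int, r: int) -> float:
    """Factor applied to the low-rank update B @ A (standard LoRA scaling, alpha / r)."""
    return lora_alpha / r


frozen = 1024 * 1024  # base weight of a 1024 -> 1024 projection, kept frozen
added = lora_extra_params(1024, 1024, r=8)
print(added, frozen)        # 16384 1048576
print(lora_scaling(16, 8))  # 2.0
```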

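We define a custom `LightningModule` to control the training and evaluation loops. Lightning then handles the boilerplate for us: moving batches to the GPU, the backward pass, optimizer steps, mixed precision, gradient clipping and accumulation, and logging. (Gradient checkpointing is enabled separately, by `prepare_model_for_kbit_training` above.)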

In [13]:
from lightning import LightningModule
from project.trainer.metrics import ao_exact_score, mr_iou_score
from deepspeed.ops.adam import DeepSpeedCPUAdam
from torch import Tensor
from bitsandbytes.optim.adam import Adam8bit


class VideoLlavaModelPLModule(LightningModule):
    def __init__(self, config, processor, model):
        super().__init__()
        self.save_hyperparameters(ignore=["model"])
        self.config = config
        self.processor = processor
        self.model = model


    def training_step(self, batch):
        input_ids: Tensor
        attention_mask: Tensor
        pixel_values_videos: Tensor
        labels: Tensor
        input_ids, attention_mask, pixel_values_videos, labels = batch

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values_videos=pixel_values_videos,
            labels=labels
        )
        loss = outputs.loss
        self.log("train_loss", loss)

        return loss


    def validation_step(self, batch):
        # eval batches also carry the frame/timestamp info needed by the IoU metric
        input_ids, attention_mask, pixel_values_videos, labels, frame_info = batch

        # autoregressively generate token IDs (greedy decoding)
        generated_ids = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values_videos=pixel_values_videos,
            max_new_tokens=50,
            do_sample=False,
        )
        # turn them back into text, chopping off the prompt
        predictions = self.processor.batch_decode(
            generated_ids[:, input_ids.size(1):],
            skip_special_tokens=True, clean_up_tokenization_spaces=True)
        score, correct = mr_iou_score(predictions, frame_info, labels)

        # "val_accuracy" holds the IoU score; passing batch_size avoids Lightning
        # having to infer it from the batch
        self.log("val_accuracy", score, batch_size=input_ids.size(0))

        return correct


    def configure_optimizers(self):
        # use 8 bit optimizer
        optimizer = Adam8bit(self.parameters(), min_8bit_size=4096, lr=self.config.get("lr"))
        # optimizer = DeepSpeedCPUAdam(self.parameters(), lr=2e-5)
        # optimizer = torch.optim.AdamW(self.parameters(), lr=self.config.get("lr"))

        return optimizer

In [14]:
module = VideoLlavaModelPLModule(
    config={
        "lr": args.learning_rate
    },
    processor=processor,
    model=model
)

In [15]:
limit_val_batches = (args.limit_val_batches // args.batch_size) * args.batch_size
train_conf = {
    "max_epochs": args.max_epoch,
    "accumulate_grad_batches": 1,
    "limit_val_batches": int(limit_val_batches),
    "val_check_interval": args.val_check_interval,
    "precision": "16-mixed",
    "gradient_clip_val": 1.0,
    "num_sanity_val_steps": None
}
print(train_conf)

{'max_epochs': 2, 'accumulate_grad_batches': 1, 'limit_val_batches': 24, 'val_check_interval': 0.25, 'precision': '16-mixed', 'gradient_clip_val': 1.0, 'num_sanity_val_steps': None}
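Assuming the 640 training samples listed for Stage 1 and the settings printed above (batch size 2, `val_check_interval=0.25`, two epochs), the arithmetic below sketches how often validation runs. Lightning treats a float `val_check_interval` as a fraction of the training batches in one epoch; the helper is hypothetical and ignores the sanity check. Each run is also capped by `limit_val_batches=24`, i.e. 24 validation samples, since the eval loader uses `batch_size=1`.

```python
import math


def val_schedule(num_train_samples: int, batch_size: int, val_check_interval: float, max_epochs: int):
    """Rough sketch: (batches per epoch, batches between validations, total validation runs)."""
    batches_per_epoch = math.ceil(num_train_samples / batch_size)
    every_n_batches = max(1, int(batches_per_epoch * val_check_interval))
    runs_per_epoch = batches_per_epoch // every_n_batches
    return batches_per_epoch, every_n_batches, runs_per_epoch * max_epochs


print(val_schedule(640, 2, 0.25, 2))  # (320, 80, 8)
```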


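We define two callbacks: early stopping and model checkpointing. Early stopping halts training when the monitored validation metric stops improving (Lightning's default patience is 3 validation runs), and model checkpointing saves the module state (model, processor, hyperparameters, and more) so that training can be resumed. The monitored `val_accuracy` is an IoU score, so higher is better and both callbacks should use `mode="max"`.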
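A simplified sketch of patience-based early stopping, written to show why the direction matters: with an IoU score, `mode="min"` treats every improvement as a regression. It mirrors the idea behind Lightning's `EarlyStopping` but not its exact implementation (no `min_delta`, no stopping thresholds), and the scores are illustrative values, not results from this run.

```python
from typing import List, Optional


def early_stop_index(scores: List[float], mode: str = "max", patience: int = 3) -> Optional[int]:
    """Index of the validation run at which training would stop, or None if it never stops."""
    best = None
    wait = 0  # validation runs since the last improvement
    for i, score in enumerate(scores):
        if best is None or (score > best if mode == "max" else score < best):
            best, wait = score, 0
        else:
            wait += 1
            if wait >= patience:
                return i
    return None


ious = [0.39, 0.45, 0.44, 0.43, 0.42]  # illustrative values, not from this run
print(early_stop_index(ious, mode="max"))  # 4: three runs without beating 0.45
print(early_stop_index(ious, mode="min"))  # 3: the rise to 0.45 counts as a regression
```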

In [16]:
# val_accuracy holds the IoU score, so higher is better: both callbacks use mode="max"
early_stopping = EarlyStopping(monitor="val_accuracy", verbose=False, mode="max")
model_checkpoint = ModelCheckpoint(
    monitor="val_accuracy",
    mode="max",
    dirpath="output/",
    filename="videollava-7b-ao-{epoch:02d}-{val_accuracy:.2f}" + f"-lora_r{args.lora_r}-lora_alpha{args.lora_alpha}"
)
callbacks = [
    early_stopping, model_checkpoint
]

In [17]:
trainer = Trainer(
    **train_conf,
    accelerator="auto",
    devices=[0],
    callbacks=callbacks,
)

Using 16bit Automatic Mixed Precision (AMP)
/opt/conda/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/amp.py:54: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


## Train

In [18]:
trainer.validate(module,eval_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.


/opt/conda/lib/python3.11/site-packages/lightning/pytorch/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 1. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      val_accuracy          0.39441820979118347
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'val_accuracy': 0.39441820979118347}]

In [19]:
trainer.fit(module, train_dataloader, eval_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                 | Params
-----------------------------------------------
0 | model | PeftModelForCausalLM | 3.8 B 
-----------------------------------------------
27.2 M    Trainable params
3.8 B     Non-trainable params
3.8 B     Total params
15,371.895Total estimated model params size (MB)


/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


`Trainer.fit` stopped: `max_epochs=2` reached.


In [21]:
trainer.validate(module,eval_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      val_accuracy          0.5136803388595581
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'val_accuracy': 0.5136803388595581}]

In [26]:
trainer.model.model.push_to_hub(Config.hub_repo, commit_message="stage-1")

CommitInfo(commit_url='https://huggingface.co/jwnt4/finetune-videollava-qlora/commit/00e3c2933eb0730cd834938d3f6f787a94a7a181', commit_message='stage-1', commit_description='', oid='00e3c2933eb0730cd834938d3f6f787a94a7a181', pr_url=None, repo_url=RepoUrl('https://huggingface.co/jwnt4/finetune-videollava-qlora', endpoint='https://huggingface.co', repo_type='model', repo_id='jwnt4/finetune-videollava-qlora'), pr_revision=None, pr_num=None)

In [28]:
trainer.model.processor.push_to_hub(Config.hub_repo, commit_message="stage-1 processor")

CommitInfo(commit_url='https://huggingface.co/jwnt4/finetune-videollava-qlora/commit/2a68c3634b86cc140db2042e904c218dbe996c67', commit_message='stage-1 processor', commit_description='', oid='2a68c3634b86cc140db2042e904c218dbe996c67', pr_url=None, repo_url=RepoUrl('https://huggingface.co/jwnt4/finetune-videollava-qlora', endpoint='https://huggingface.co', repo_type='model', repo_id='jwnt4/finetune-videollava-qlora'), pr_revision=None, pr_num=None)

## Inference: original vs. fine-tuned model

### Original model

In [1]:
# del model
# del trainer

In [11]:
test_dataloader = DataLoader(dataset['test'], collate_fn=DataCollatorWithPadding(processor), batch_size=1, shuffle=False, num_workers=1)

In [12]:
model = VideoLlavaForConditionalGeneration.from_pretrained(
    "LanguageBind/Video-LLaVA-7B-hf",
    torch_dtype=torch.float16,
    _attn_implementation="flash_attention_2",
    device_map="auto"
)

model.generation_config.max_new_tokens = 32
model.config.return_dict = True

The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.


In [13]:
final_score = 0
with torch.inference_mode():
    for batch in test_dataloader:
        input_ids, attention_mask, pixel_values_videos, labels, frame_info = batch

        # autoregressively generate token IDs (greedy decoding)
        generated_ids = model.generate(
            input_ids=input_ids.to(model.device),
            attention_mask=attention_mask.to(model.device),
            pixel_values_videos=pixel_values_videos.to(model.device),
            max_new_tokens=50,
            do_sample=False,
        )
        # turn them back into text, chopping off the prompt
        predictions = processor.batch_decode(
            generated_ids[:, input_ids.size(1):],
            skip_special_tokens=True, clean_up_tokenization_spaces=True)
        print(f"prediction: {predictions}. Labels: {labels}")
        # only the IoU score is needed here; the second return value is discarded
        score, _ = mr_iou_score(predictions, frame_info, labels)
        final_score += score

prediction: ['00:00, 00:29']. Labels: [['00:15', '00:29']]
prediction: ['00:31, 00:40']. Labels: [['00:25', '00:40']]
prediction: ['00:10, 00:30']. Labels: [['00:10', '00:33']]
prediction: ['00:00, 00:30']. Labels: [['00:06', '00:41']]
prediction: ['00:10, 00:30']. Labels: [['00:15', '00:32']]
prediction: ['00:10, 00:30']. Labels: [['00:11', '00:20']]
prediction: ['00:30, 00:45']. Labels: [['00:00', '00:39']]
prediction: ['00:10, 00:30']. Labels: [['00:05', '00:34']]
prediction: ['00:10, 00:30']. Labels: [['00:10', '00:15']]
prediction: ['00:10, 00:30']. Labels: [['00:21', '00:35']]
prediction: ['00:10, 00:30']. Labels: [['00:10', '00:11']]
prediction: ['00:38, 00:50']. Labels: [['00:27', '00:46']]
prediction: ['00:00, 00:30']. Labels: [['00:05', '00:23']]
prediction: ['00:00, 00:39']. Labels: [['00:06', '00:39']]
prediction: ['00:00, 00:30']. Labels: [['00:12', '00:28']]
prediction: ['00:10, 00:30']. Labels: [['00:14', '00:26']]
prediction: ['00:00, 00:30']. Labels: [['00:07', '00:23'

In [14]:
# mean IoU over the test set (batch_size=1, so one score per sample)
f"Score for original model: {final_score / len(test_dataloader):.2f}"

'Score for original model: 0.42'

### Fine-tuned model

In [15]:
del model

In [18]:
model = VideoLlavaForConditionalGeneration.from_pretrained(
    Config.hub_repo,
    torch_dtype=torch.float16,
    device_map="auto"
)

model.generation_config.max_new_tokens = 32
model.config.return_dict = True
model

The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.


VideoLlavaForConditionalGeneration(
  (video_tower): CLIPVisionModel(
    (vision_model): CLIPVisionTransformer(
      (embeddings): CLIPVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
        (position_embedding): Embedding(257, 1024)
      )
      (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (encoder): CLIPEncoder(
        (layers): ModuleList(
          (0-23): 24 x CLIPEncoderLayer(
            (self_attn): CLIPSdpaAttention(
              (k_proj): lora.Linear(
                (base_layer): Linear(in_features=1024, out_features=1024, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1024, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, 

In [19]:
new_score = 0
with torch.inference_mode():
    for batch in test_dataloader:
        input_ids, attention_mask, pixel_values_videos, labels, frame_info = batch

        # autoregressively generate token IDs (greedy decoding)
        generated_ids = model.generate(
            input_ids=input_ids.to(model.device),
            attention_mask=attention_mask.to(model.device),
            pixel_values_videos=pixel_values_videos.to(model.device),
            max_new_tokens=50,
            do_sample=False,
        )
        # turn them back into text, chopping off the prompt
        predictions = processor.batch_decode(
            generated_ids[:, input_ids.size(1):],
            skip_special_tokens=True, clean_up_tokenization_spaces=True)
        print(f"prediction: {predictions}. Labels: {labels}")
        # only the IoU score is needed here; the second return value is discarded
        score, _ = mr_iou_score(predictions, frame_info, labels)
        new_score += score

prediction: ['00:07, 00:29']. Labels: [['00:15', '00:29']]
prediction: ['00:22, 00:40']. Labels: [['00:25', '00:40']]
prediction: ['00:08, 00:20']. Labels: [['00:10', '00:33']]
prediction: ['00:13, 00:38']. Labels: [['00:06', '00:41']]
prediction: ['00:22, 00:32']. Labels: [['00:15', '00:32']]
prediction: ['00:07, 00:20']. Labels: [['00:11', '00:20']]
prediction: ['00:21, 00:39']. Labels: [['00:00', '00:39']]
prediction: ['00:24, 00:34']. Labels: [['00:05', '00:34']]
prediction: ['00:08, 00:33']. Labels: [['00:10', '00:15']]
prediction: ['00:16, 00:35']. Labels: [['00:21', '00:35']]
prediction: ['00:08, 00:21']. Labels: [['00:10', '00:11']]
prediction: ['00:34, 00:50']. Labels: [['00:27', '00:46']]
prediction: ['00:05, 00:27']. Labels: [['00:05', '00:23']]
prediction: ['00:06, 00:39']. Labels: [['00:06', '00:39']]
prediction: ['00:24, 00:31']. Labels: [['00:12', '00:28']]
prediction: ['00:06, 00:26']. Labels: [['00:14', '00:26']]
prediction: ['00:05, 00:28']. Labels: [['00:07', '00:23'

In [20]:
# mean IoU over the test set (batch_size=1, so one score per sample)
f"Score for trained model: {new_score / len(test_dataloader):.2f}"

'Score for trained model: 0.55'