<a href="https://colab.research.google.com/github/dancher00/vla/blob/main/HW1_Template%2C_fine_tune_MM_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Prerequisites
Before we start, make sure you have the following:

- Access to a GPU (preferably an A100, since videos require long sequence lengths).
- Familiarity with Hugging Face’s Transformers library.
- The necessary packages installed (run the cell below).

In [None]:
# !pip install -U -q transformers accelerate bitsandbytes peft datasets
# !pip install -q av decord
# !pip install -q lightning
# !pip install -q pyarrow==15.0.0
# !pip install -q wandb
# !pip install -q timm deepspeed flash_attn
# restart notebooks here


## Fine-tune InternVL 1B on the MVBench dataset

In this notebook, you need to fine-tune the [InternVL](https://huggingface.co/OpenGVLab/InternVL3_5-1B) model on the [MVBench](https://huggingface.co/datasets/OpenGVLab/MVBench) dataset, which comprises various video-related tasks. Note that MVBench is quite small and was not designed for fine-tuning, so you first need to split it into training and test parts.

The goal for the model in this notebook is to answer multiple-choice questions based on a video. The questions can be related to temporal aspects of the video, pose prediction and so on.

Sources:

* InternVL [documentation](https://internvl.readthedocs.io/en/latest/internvl2.0/introduction.html)
* InternVL [checkpoint on the hub](https://huggingface.co/OpenGVLab/InternVL2-1B)

## Define variables

We'll first set some variables that are useful throughout this notebook and do all the necessary imports.

In [None]:
import os
import av
import re
import gc
import sys
import json
import random
import bisect
import shutil
import traceback
import numpy as np
from nltk import edit_distance
from pathlib import Path
from copy import deepcopy
from typing import Dict, Any, List, Union, Tuple

from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                          HfArgumentParser, Trainer, TrainingArguments,
                          set_seed, AutoProcessor)
from transformers import BitsAndBytesConfig
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from huggingface_hub import snapshot_download, hf_hub_download
from datasets import load_dataset, concatenate_datasets
from PIL import Image, ImageFile, PngImagePlugin, UnidentifiedImageError

import lightning as L
from lightning.pytorch.callbacks.early_stopping import EarlyStopping, Callback


MAX_LENGTH = 160
MODEL_ID = "OpenGVLab/InternVL3_5-1B"
REPO_ID = "xxxx" # Change to your hf-hub repo

os.environ["WANDB_API_KEY"] = "xxxx" # Change to your W&B profile if you need it
os.environ["WANDB_MODE"] = "online"

from huggingface_hub import login
access_token = "xxxxx" # Change to your HF access token
login(access_token)

USE_LORA = False
USE_QLORA = True

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: fineGrained).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [None]:
! git clone https://github.com/OpenGVLab/InternVL.git

Cloning into 'InternVL'...
remote: Enumerating objects: 2483, done.
remote: Counting objects: 100% (941/941), done.
remote: Compressing objects: 100% (173/173), done.
remote: Total 2483 (delta 797), reused 789 (delta 758), pack-reused 1542 (from 1)
Receiving objects: 100% (2483/2483), 36.32 MiB | 21.04 MiB/s, done.
Resolving deltas: 100% (1522/1522), done.


In [None]:
! ls InternVL/internvl_chat

eval	     pyproject.toml  zero_stage1_config.json	   zero_stage3_config_70b.json
evaluate.sh  README.md	     zero_stage2_config.json	   zero_stage3_config.json
examples     shell	     zero_stage3_config_100b.json
internvl     tools	     zero_stage3_config_34b.json


In [None]:
sys.path.append('./InternVL/internvl_chat')

In [None]:
from internvl.dist_utils import init_dist
# from internvl.model.internlm2.modeling_internlm2 import InternLM2ForCausalLM
from internvl.model.internvl_chat import (InternVisionConfig,
                                          InternVisionModel,
                                          InternVLChatConfig,
                                          InternVLChatModel)
from internvl.patch import (concat_pad_data_collator,
                            replace_llama_rmsnorm_with_fused_rmsnorm,
                            replace_train_sampler)
from internvl.train.constants import (BOX_END_TOKEN, BOX_START_TOKEN,
                                      IMG_CONTEXT_TOKEN, IMG_END_TOKEN,
                                      IMG_START_TOKEN, QUAD_END_TOKEN,
                                      QUAD_START_TOKEN, REF_END_TOKEN,
                                      IMAGENET_MEAN, IMAGENET_STD,
                                      REF_START_TOKEN)
from internvl.train.dataset import (ConcatDataset, read_frames_decord,
                                    WeightedConcatDataset, build_transform,
                                    dynamic_preprocess, preprocess,
                                    preprocess_internlm, preprocess_mpt,
                                    preprocess_phi3,)

# MVBench benchmark

[MVBench on HF Datasets](https://huggingface.co/datasets/OpenGVLab/MVBench)

<!-- ![MVbench1.png](https://huggingface.co/datasets/OpenGVLab/MVBench/resolve/main/assert/generation.png) -->

It consists of 20 temporal tasks; an example of each is shown below.

![MVbench-structure.png](https://huggingface.co/datasets/OpenGVLab/MVBench/resolve/main/assert/task_example.png)


Here we have a nice viewer for each task:

[Dataset viewer](https://huggingface.co/datasets/OpenGVLab/MVBench/viewer/action_sequence)



We will start by downloading and processing the dataset. Even though MVBench is a small dataset, it still requires **around 1000B to store the videos**, so make sure you have enough free space.

First, we will use the mapping below to locate the data, because each task is a separate subset in its own folder. Then we need a few helper functions to download the videos and process them into the model's format (8 frames per video).

In [None]:
data_list = {
    "Action Sequence": ("action_sequence.json", "star/Charades_v1_480/", "video", True), # has start & end
    "Action Prediction": ("action_prediction.json", "star/Charades_v1_480/", "video", True), # has start & end
    "Action Antonym": ("action_antonym.json", "ssv2_video/", "video", False),
    "Fine-grained Action": ("fine_grained_action.json", "Moments_in_Time_Raw/videos/", "video", False),
    "Unexpected Action": ("unexpected_action.json", "FunQA_test/test/", "video", False),
    "Object Existence": ("object_existence.json", "clevrer/video_validation/", "video", False),
    "Object Interaction": ("object_interaction.json", "star/Charades_v1_480/", "video", True), # has start & end
    "Object Shuffle": ("object_shuffle.json", "perception/videos/", "video", False),
    "Moving Direction": ("moving_direction.json", "clevrer/video_validation/", "video", False),
    "Action Localization": ("action_localization.json", "sta/sta_video/", "video", True),  # has start & end
    "Scene Transition": ("scene_transition.json", "scene_qa/video/", "video", False),
    "Action Count": ("action_count.json", "perception/videos/", "video", False),
    "Moving Count": ("moving_count.json", "clevrer/video_validation/", "video", False),
    "Moving Attribute": ("moving_attribute.json", "clevrer/video_validation/", "video", False),
    "State Change": ("state_change.json", "perception/videos/", "video", False),
    "Fine-grained Pose": ("fine_grained_pose.json", "nturgbd/", "video", False),
    "Character Order": ("character_order.json", "perception/videos/", "video", False),
    "Egocentric Navigation": ("egocentric_navigation.json", "vlnqa/", "video", False),
    "Episodic Reasoning": ("episodic_reasoning.json", "tvqa/frames_fps3_hq/", "frame", True),  # has start & end, read frame
    "Counterfactual Inference": ("counterfactual_inference.json", "clevrer/video_validation/", "video", False),
}

data_dir = "dataset"
os.makedirs(data_dir, exist_ok=True)

In [None]:
def read_video_pyav(video_path, start, end, n_frames=8):
    """
    Reads a video for the given start-end timestamp interval
    and uniformly samples `n_frames` frames from it.
    """
    container = av.open(video_path)
    video = container.streams.video[0]  # first video stream (stream 0 may be audio)

    av_timestamps = [
        int(packet.pts * video.time_base) for packet in container.demux(video) if packet.pts is not None
    ]

    av_timestamps.sort()
    start_id = bisect.bisect_left(av_timestamps, start)
    end_id = bisect.bisect_left(av_timestamps, end)

    # in case it is a very short video, lets take a longer duration and sample
    if end_id  - start_id < 10:
        end_id += 10
        start_id -= 10

    end_id = min(len(av_timestamps) - 1, end_id)
    start_id = max(1, start_id)

    # We sample n_frames frames for tuning following the original paper
    # But we can increase the number of frames for longer videos and check out if it helps performance
    # Change the below "n_frames" to any number of frames you want, and note that more frames -> more computational resources needed
    indices = np.linspace(start_id, end_id, n_frames).astype(int)

    frames = []
    container.seek(0)
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_id:
            break
        if i >= start_id and i in indices:
            frames.append(frame)
    assert len(frames) == n_frames, f"Got {len(frames)} frames but should be {n_frames}. Check the indices: {indices};, start_id: {start_id}, end_id: {end_id}. Len of video is {len(av_timestamps)} frames."
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])
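The index selection inside `read_video_pyav` can be tried in isolation. The sketch below reproduces it in plain Python, assuming the same "evenly spaced, then truncate to int" behaviour as `np.linspace(...).astype(int)`; the helper name `uniform_frame_indices` is hypothetical and not part of the notebook. It also shows why very short clips can trip the assertion at the end of the function: truncation yields duplicate indices, so fewer than `n_frames` distinct frames are collected.

```python
def uniform_frame_indices(start_id, end_id, n_frames=8):
    """Evenly spaced integer frame indices in [start_id, end_id], endpoints included."""
    if n_frames == 1:
        return [start_id]
    span = end_id - start_id
    # Multiply before dividing so the last index lands exactly on end_id.
    return [int(start_id + span * i / (n_frames - 1)) for i in range(n_frames)]


# A clip that is long enough: 8 distinct frames.
print(uniform_frame_indices(10, 80))  # [10, 20, 30, 40, 50, 60, 70, 80]

# A clip with only 4 frames available: indices repeat.
short = uniform_frame_indices(1, 4)
print(short)            # [1, 1, 1, 2, 2, 3, 3, 4]
print(len(set(short)))  # 4
```

One possible way to handle such clips is to lower `n_frames` for them or to widen the interval, which is what the "very short video" branch of the function attempts.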

In [None]:
def collate_read_video(example, path):
    # Some datasets have a start-end interval, so we try to get it if exists.
    # Otherwise just set a very large end timestamp
    clip = read_video_pyav(f'{path}/{example["video"]}', example.get("start", 1), example.get("end", 1e+10))
    example["clip"] = clip
    return example

In [None]:
# Download the videos from datasets repo and unzip.
# Make sure you have enough free space before downloading and unzipping

# videos = snapshot_download(repo_id="OpenGVLab/MVBench", allow_patterns="*", repo_type="dataset")
# for zip_file in os.listdir(f"{videos}/video"):
#     if zip_file.endswith(".zip"):
#         shutil.unpack_archive(f"{videos}/video/{zip_file}", f"{videos}/videos_unzipped/")

In [None]:
# Or download only selected task with appropriate files
# https://huggingface.co/docs/huggingface_hub/v0.24.7/en/package_reference/file_download#huggingface_hub.hf_hub_download

TASK_NAME = "Action Sequence"
annotation_fn, video_dir, video_type, has_clip = data_list.get(TASK_NAME)
print(f"Task: {TASK_NAME}")
print(f"Annotation file: {annotation_fn}")
print(f"Videos are stored in: {video_dir}")
print(f"Videos are represented as: {video_type}")
print(f"Videos have start/end timestamps: {has_clip}")

annotation_fn_local = hf_hub_download(repo_id="OpenGVLab/MVBench",
                                    filename='json/' + annotation_fn,
                                    repo_type="dataset",
                                    local_dir=data_dir)
videos_zip = hf_hub_download(repo_id="OpenGVLab/MVBench",
                            filename='video/' + video_dir.split("/")[0] + ".zip",
                            repo_type="dataset",
                            local_dir=data_dir)

# Unzip
for zip_file in os.listdir(f"{data_dir}/video"):
    if zip_file.endswith(".zip"):
        print(zip_file)
        shutil.unpack_archive(f"{data_dir}/video/{zip_file}", f"{data_dir}/video/videos_unzipped/")

Task: Action Sequence
Annotation file: action_sequence.json
Videos are stored in: star/Charades_v1_480/
Videos are represented as: video
Videos have start/end timestamps: True
star.zip


Create the following data structure:

```
dataset/
    /json
        task_name.json
    /video
        /task_name_prefix (optional)
            /task_name
                video_0.mp4
                video_1.mp4
                video_2.mp4
                ...
                video_100.mp4
```





In [None]:
ds = load_dataset("json", data_files=annotation_fn_local, split="train")
ds

Dataset({
    features: ['video', 'question', 'candidates', 'answer', 'start', 'end'],
    num_rows: 200
})

In [None]:
# Some tasks in MVBench are missing video files - keep it in mind!
has_missing = False
for sample in ds:
    if not os.path.exists(f"{data_dir}/video/videos_unzipped/{video_dir}/{sample['video']}"):
        print(f"Video `{sample['video']}` does not exist!")
        has_missing = True

Video `EDXBD.mp4` does not exist!
Video `K47J5.mp4` does not exist!
Video `9MNZ5.mp4` does not exist!
Video `QXT9W.mp4` does not exist!
Video `ABHC6.mp4` does not exist!
Video `ALXUC.mp4` does not exist!
Video `BAUQE.mp4` does not exist!
Video `PHH6B.mp4` does not exist!
Video `MNC10.mp4` does not exist!
Video `W7CR5.mp4` does not exist!
Video `Q8UJ8.mp4` does not exist!
Video `X9WTR.mp4` does not exist!


In [None]:
print(f"Dataset length = {len(ds)}")
if has_missing:
    ds = ds.filter(lambda x: os.path.exists(f"{data_dir}/video/videos_unzipped/{video_dir}/{x['video']}"))

print(f"Dataset length = {len(ds)}")

Dataset length = 200
Dataset length = 188


In [None]:
# Load videos and split them into frames
# ds = ds.map(collate_read_video,
#             batched=False,
#             fn_kwargs={"path": f"{data_dir}/video/videos_unzipped/{video_dir}"})

# Make conversation

def make_conversation(sample):
    id2choice = {0: "A", 1: "B", 2: "C", 3: "D"}
    question, candidates = sample["question"], sample["candidates"]
    answer_i = candidates.index(sample["answer"])
    answer_choice = id2choice[answer_i]

    mult_choice = "\n"
    for i, choice in enumerate(candidates):
        mult_choice += f"{id2choice[i]}. {choice};\n"

    conversations = [
         {'from': 'human', 'value': question + mult_choice + '<video>'},
         {'from': 'gpt', 'value': answer_choice + " " + sample["answer"]}
    ]
    sample['conversations'] = conversations
    return sample

ds = ds.map(make_conversation,
            batched=False)

Map:   0%|          | 0/188 [00:00<?, ? examples/s]

In [None]:
ds[0]

{'video': 'ZS9XR.mp4',
 'question': 'What happened after the person took the food?',
 'candidates': ['Ate the medicine.',
  'Tidied up the blanket.',
  'Put down the cup/glass/bottle.',
  'Took the box.'],
 'answer': 'Ate the medicine.',
 'start': 1.5,
 'end': 17.1,
 'conversations': [{'from': 'human',
   'value': 'What happened after the person took the food?\nA. Ate the medicine.;\nB. Tidied up the blanket.;\nC. Put down the cup/glass/bottle.;\nD. Took the box.;\n<video>'},
  {'from': 'gpt', 'value': 'A Ate the medicine.'}]}



For comparison, a video sample in the conversation format used by InternVL training data looks like this:

```json
{'id': 578004, 'video': '013901_013950/1025449886.mp4',
'conversations':
[{'from': 'human', 'value': 'Render a clear and concise summary of the video below.\n<video>'},
{'from': 'gpt', 'value': 'Aerial; drone flight around dangerous eroded steep stony slopes of cabo da roca; granite boulders, sea cliffs along the high coast; long seashore with lighthouse, overlooking the promontory, portugal'}]
}
```



In [None]:
# Load model's processor
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
processor.padding_side = "right"

processor

Qwen2TokenizerFast(name_or_path='OpenGVLab/InternVL2-1B', vocab_size=151643, model_max_length=8192, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'eos_token': '<|im_end|>', 'pad_token': '<|endoftext|>', 'additional_special_tokens': ['<|im_start|>', '<|im_end|>', '<img>', '</img>', '<IMG_CONTEXT>', '<quad>', '</quad>', '<ref>', '</ref>', '<box>', '</box>']}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	151643: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	151644: AddedToken("<|im_start|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	151645: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	151646: AddedToken("<img>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	151647: AddedToken("</img>", rstrip=False, lstrip=False, single_word=False, normalized=Fals

## Custom Dataset Class

In the next step, you'll need **to define a custom dataset** class and the functions that prepare the data for fine-tuning the model. The VideoQADataset class extends the [PyTorch Dataset](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html) class to load and process MVBench. It converts dataset samples into the format required for training and evaluation by preparing a prompt and turning each video into an array.

Next, you need **to define collate functions** to handle the batching of data during training and evaluation. These functions ensure that the input data is properly formatted and padded.

Here, use the processor to turn each (video, target token sequence) pair into the format that the model expects (pixel_values, input_ids, etc.). Pad the batches dynamically: each batch contains ground-truth sequences of varying lengths.

You can also cap the length of the text tokens (input_ids) at a maximum length because of memory constraints. Feel free to raise it if your target token sequences are longer (I'd recommend plotting the distribution of token lengths in your dataset to determine a suitable value).
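One simple way to turn a list of token lengths into a cap is to pick the smallest length that covers a chosen fraction of the samples. The sketch below is only an illustration: the helper `length_covering` and the example lengths are made up, and in practice the lengths would come from tokenizing your own conversations.

```python
import math


def length_covering(lengths, coverage=0.95):
    """Smallest length L such that at least `coverage` of the samples have length <= L."""
    ordered = sorted(lengths)
    k = max(0, math.ceil(coverage * len(ordered)) - 1)
    return ordered[k]


# Hypothetical token lengths for ten samples, including one long outlier.
token_lengths = [40, 52, 61, 75, 80, 90, 96, 110, 120, 300]
print(length_covering(token_lengths, coverage=0.9))  # 120
print(length_covering(token_lengths, coverage=1.0))  # 300
```

Choosing a cap slightly below the maximum trades truncation of a few outliers for a noticeably shorter padded batch.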

The formatting of the input_ids is super important: you need to respect a so-called [chat template](https://huggingface.co/docs/transformers/main/en/chat_templating).

Labels are created by simply copying the inputs to the LLM (input_ids) and replacing the padding tokens with the ignore index of the loss function. This way the model is not trained to predict padding tokens, which are only used to batch examples together.
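The padding and label construction described above can be sketched with plain Python lists. This is a minimal illustration, not the collate function the exercise asks for: `pad_batch` and the token ids are hypothetical, right padding is assumed (matching `padding_side = "right"`), and `-100` is the default `ignore_index` of PyTorch's cross-entropy loss.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def pad_batch(sequences, pad_token_id):
    """Right-pad token id lists to the longest one and build masks and labels."""
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask, labels = [], [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        input_ids.append(seq + [pad_token_id] * n_pad)
        attention_mask.append([1] * len(seq) + [0] * n_pad)
        # Labels copy the inputs, except that padding positions are ignored by the loss.
        labels.append(seq + [IGNORE_INDEX] * n_pad)
    return input_ids, attention_mask, labels


ids, mask, labels = pad_batch([[5, 6, 7], [8, 9]], pad_token_id=0)
print(ids)     # [[5, 6, 7], [8, 9, 0]]
print(mask)    # [[1, 1, 1], [1, 1, 0]]
print(labels)  # [[5, 6, 7], [8, 9, -100]]
```

A real `train_collate_fn` would do the same on tensors and would also stack the video frames of each example.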

Why are the labels a copy of the model inputs, you may ask? The model internally shifts the labels by one position, so that the output at each position is trained to predict the next token. This can be seen here.
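To make the shift concrete, the following pure-Python sketch lists which (input token, target token) pairs actually contribute to the loss. It is an illustration of the idea only, not the model's code; the function name `next_token_pairs` and the token ids are hypothetical.

```python
def next_token_pairs(input_ids, labels, ignore_index=-100):
    """Pair the input at position t with the label at t + 1, skipping ignored targets."""
    pairs = []
    for t in range(len(input_ids) - 1):
        target = labels[t + 1]
        if target != ignore_index:
            pairs.append((input_ids[t], target))
    return pairs


# One right-padded sequence: the last position is padding and is ignored.
print(next_token_pairs([5, 6, 7, 0], [5, 6, 7, -100]))  # [(5, 6), (6, 7)]
```

Because the shift happens inside the model, the collate function only has to provide labels aligned with `input_ids`.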

The collate function for evaluation is different, since there you only need to feed the prompt to the model, as we'll use the `generate()` method to autoregressively generate a completion.

In [None]:
class VideoQADataset(Dataset):
    """Dataset for Video QA fine-tuning."""

    def __init__(
        self,
        template_name,
        raw_data,
        video_data_dir,
        tokenizer,
        ds_name,
        num_image_token,
        image_size=224,
        is_train=True,
        pad2square=False,
        dynamic_image_size=False,
        use_thumbnail=False,
        min_dynamic_patch=1,
        max_dynamic_patch=6,
        min_num_frame=4,  # for video data
        max_num_frame=12,  # for video data
        sampling_method='rand',  # for video data
        repeat_time=1,
        normalize_type='imagenet',
        random_seed=0,
    ):
        super(VideoQADataset, self).__init__()
        self.ds_name = ds_name
        self.tokenizer = tokenizer
        self.template_name = template_name
        self.num_image_token = num_image_token
        print(f'[Dataset] num_image_token: {num_image_token}')
        print(f'[Dataset] dynamic_image_size: {dynamic_image_size}')
        print(f'[Dataset] use_thumbnail: {use_thumbnail}')
        print(f'[Dataset] min_dynamic_patch: {min_dynamic_patch}, max_dynamic_patch: {max_dynamic_patch}')

        self.image_size = image_size
        self.is_train = is_train
        self.pad2square = pad2square
        self.max_num_frame = max_num_frame
        self.min_num_frame = min_num_frame
        self.sampling_method = sampling_method

        print('Formatting inputs...Skip in lazy mode')

        self.raw_data = raw_data
        self.rng = np.random.default_rng(seed=random_seed)
        self.rng.shuffle(self.raw_data)

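        gc.collect()
        self.root = video_data_dir
        self.cached_data_dict = {}
        # Attributes used by the methods below but not passed to __init__.
        # Assumed defaults: no remote (ceph/s3) loader, no length grouping,
        # and max_num_frame frames sampled per video.
        self.tcs_loader = None
        self.group_by_length = False
        self.num_frames = max_num_frame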
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail = use_thumbnail
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch
        self.normalize_type = normalize_type

        gc.collect()

    def __len__(self):
        return len(self.raw_data)

    def get_preprocess_function(self):
        # Select the appropriate preprocessing function based on the template name
        return preprocess_internlm

    def load_image(self, image_path):
        # Load the image using tcs_loader if available, otherwise use PIL
        if self.tcs_loader is not None and 's3://' in image_path:
            return self.tcs_loader(image_path)
        return Image.open(image_path).convert('RGB')

    def get_image_path(self, image_path):
        if image_path.startswith('s3://'):  # for ceph
            image_path = self.root + image_path
        else:  # for local image
            image_path = os.path.join(self.root, image_path)
        return image_path

    def get_transform(self):
        # Build transformation function
        transform = build_transform(is_train=self.is_train, input_size=self.image_size,
                                    pad2square=self.pad2square, normalize_type=self.normalize_type)
        return transform


    def video_get_item(self, data_item):
        # Build transformation function
        transform = self.get_transform()

        # Ensure the first conversation contains a video placeholder
        if '<video>' not in data_item['conversations'][0]['value']:
            data_item['conversations'][0]['value'] = '<video>\n' + data_item['conversations'][0]['value']

        # Get the video file path
        video_file = data_item['video']
        video_path = os.path.join(self.root, video_file)

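        # Load the video frames with decord; pass a clip only when both timestamps exist
        start, end = data_item.get('start', None), data_item.get('end', None)
        clip = (start, end) if start is not None and end is not None else None
        image_list = read_frames_decord(video_path,
                                        num_frames=self.num_frames,
                                        client=self.tcs_loader,
                                        clip=clip)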

        # Generate special tokens for each video frame
        special_tokens = '\n'.join(['Frame{}: <image>'.format(i + 1) for i in range(len(image_list))])
        data_item['conversations'][0]['value'] = data_item['conversations'][0]['value'].replace(
            '<video>', special_tokens)

        # Transform each frame image and stack them into a tensor
        pixel_values = [transform(image) for image in image_list]
        pixel_values = torch.stack(pixel_values)
        num_patches = pixel_values.size(0)

        # Select the appropriate preprocessing function based on the template name
        preprocess_function = self.get_preprocess_function()

        # Preprocess the conversations and generate the return dictionary
        num_image_tokens = [self.num_image_token] * num_patches
        ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
                                  self.tokenizer, num_image_tokens, group_by_length=self.group_by_length,
                                  ds_name=self.ds_name, num_image=num_patches)

        # Create the final return dictionary
        ret = dict(
            input_ids=ret['input_ids'][0],
            labels=ret['labels'][0],
            attention_mask=ret['attention_mask'][0],
            pixel_values=pixel_values,
            image_flags=torch.tensor([1] * num_patches, dtype=torch.long)
        )
        return ret

    def pure_text_get_item(self, data_item):
        # Build transformation function
        transform = self.get_transform()

        # Create a blank white image
        image = Image.new('RGB', (224, 224), (255, 255, 255))

        # Dynamically preprocess the image to generate patches
        images = dynamic_preprocess(image, min_num=self.min_dynamic_patch, max_num=1,
                                    image_size=self.image_size, use_thumbnail=self.use_thumbnail)

        # Apply the transformation to each image patch and stack them into a tensor
        pixel_values = [transform(image) for image in images]
        pixel_values = torch.stack(pixel_values)
        num_patches = pixel_values.size(0)

        # Ensure there is only one patch
        assert num_patches == 1, f'The number of patches should be 1, but got {num_patches}.'

        # Select the appropriate preprocessing function based on the template name
        preprocess_function = self.get_preprocess_function()

        # Preprocess the conversations and generate the return dictionary
        ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
                                  self.tokenizer, [self.num_image_token * num_patches], text_only=True,
                                  group_by_length=self.group_by_length, ds_name=self.ds_name)

        # Create the final return dictionary
        ret = dict(
            input_ids=ret['input_ids'][0],
            labels=ret['labels'][0],
            attention_mask=ret['attention_mask'][0],
            pixel_values=pixel_values,
            image_flags=torch.tensor([0] * num_patches, dtype=torch.long)
        )
        return ret

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        i = i % len(self.raw_data)
        while True:
            try:
                data_item = json.loads(self.raw_data[i])
                if 'image' in data_item and len(data_item['image']) != 0:
                    if type(data_item['image']) == list:
                        ret = self.multi_modal_multi_image_get_item(data_item)
                    else:
                        ret = self.multi_modal_get_item(data_item)
                elif 'video' in data_item and data_item['video'] is not None and data_item['video'] != '':
                    ret = self.video_get_item(data_item)
                else:
                    ret = self.pure_text_get_item(data_item)
                break
            except Exception as e:
                print(e, self.ds_name, flush=True)
                if not isinstance(e, UnidentifiedImageError):
                    traceback.print_exc()
                data_item = json.loads(self.raw_data[i])
                if 'image' in data_item:
                    if type(data_item['image']) == list:
                        images = [self.root + item for item in data_item['image']]
                        print(f'Failed to load image: {images}, the dataset is: {self.ds_name}')
                    else:
                        if data_item['image'].startswith('s3://'):
                            data_path = self.root + data_item['image']
                        else:
                            data_path = os.path.join(self.root, data_item['image'])
                        print(f'Failed to load image: {data_path}, the dataset is: {self.ds_name}')
                elif 'video' in data_item:
                    data_path = os.path.join(self.root, data_item['video'])
                    print(f'Failed to load video: {data_path}, the dataset is: {self.ds_name}')
                i = random.randint(0, len(self.raw_data) - 1)
        return ret
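The prompt rewriting done in `video_get_item` is easy to check on its own. The sketch below mirrors that step in a standalone function (the name `expand_video_placeholder` is hypothetical): the `<video>` placeholder is added if missing and then replaced by one `FrameN: <image>` line per sampled frame.

```python
def expand_video_placeholder(prompt, n_frames):
    """Replace '<video>' with one 'FrameN: <image>' line per frame."""
    if "<video>" not in prompt:
        prompt = "<video>\n" + prompt
    frame_tokens = "\n".join(f"Frame{i + 1}: <image>" for i in range(n_frames))
    return prompt.replace("<video>", frame_tokens)


print(expand_video_placeholder("What happened after the person took the food?\n<video>", 3))
# What happened after the person took the food?
# Frame1: <image>
# Frame2: <image>
# Frame3: <image>
```

Each `<image>` marker is later expanded by the preprocessing function into `num_image_token` context tokens, so the text length grows with the number of frames.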

In [None]:
def train_collate_fn(examples):

    # YOUR CODE HERE

    return input_ids, attention_mask, pixel_values_videos, labels


def eval_collate_fn(examples):

    # YOUR CODE HERE

    return input_ids, attention_mask, pixel_values_videos, answer_choice

## Shuffling and Splitting the Dataset
You need to shuffle the dataset and then split it into training and test sets. This ensures that the model is trained and evaluated on a diverse and representative sample of the data.
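As a rough sketch of the idea, here is a seeded shuffle-and-split in plain Python. The helper `shuffle_and_split`, the 20% test fraction and the seed are assumptions for illustration; with a Hugging Face `Dataset`, the built-in `train_test_split` method gives a comparable result with `"train"` and `"test"` keys.

```python
import random


def shuffle_and_split(items, test_fraction=0.2, seed=42):
    """Shuffle reproducibly, then hold out a fraction of the items for testing."""
    rng = random.Random(seed)
    shuffled = list(items)
    rng.shuffle(shuffled)
    n_test = max(1, int(len(shuffled) * test_fraction))
    return {"train": shuffled[n_test:], "test": shuffled[:n_test]}


# 188 samples remain after filtering out the missing videos.
splits = shuffle_and_split(range(188))
print(len(splits["train"]), len(splits["test"]))  # 151 37
```

Fixing the seed keeps the split identical across notebook restarts, which matters when comparing runs on such a small dataset.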


In [None]:
# YOUR CODE HERE

In [None]:
%matplotlib inline

from matplotlib import pyplot as plt
from matplotlib import animation
from IPython.display import HTML


example = dataset['train'][0]
clip = example["clip"]
# np array with shape (frames, height, width, channels)
video = np.array(clip)


# Visualize your data
# YOUR CODE HERE

In [None]:
fig = plt.figure()
im = plt.imshow(video[0,:,:,:])

plt.close() # this is required to not display the generated image

def init():
    im.set_data(video[0,:,:,:])

def animate(i):
    im.set_data(video[i,:,:,:])
    return im

anim = animation.FuncAnimation(fig, animate, init_func=init, frames=video.shape[0],
                               interval=100)
HTML(anim.to_html5_video())

Now wrap it in the PyTorch Dataset class and print one example as a sanity check.

In [None]:
train_dataset = VideoLlavaDataset(dataset["train"])
eval_dataset = VideoLlavaDataset(dataset["test"])

In [None]:
prompt, clip = train_dataset[0]

# Model

## Load model
Next, load your InternVL model from the hub. This is a model with about 1 billion trainable parameters (it combines a **Qwen2 1B language model** with a relatively small **InternViT vision encoder**). Note that the checkpoint loaded here has already undergone supervised fine-tuning (SFT) on an instruction dataset, so we can benefit from that earlier fine-tuning.

## Full fine-tuning, LoRA and QLoRA

**Select the fine-tuning method.**

For reference, fine-tuning a model with the AdamW optimizer (which is often used to optimize neural networks) in mixed precision requires about 18 bytes of GPU RAM per parameter. So in this case we would need 18 x 1 billion bytes = 18 GB of GPU RAM to update all the parameters of the model. Not so huge, right? With a PEFT approach it can be even less.
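The 18x figure can be reproduced with a small calculation. The per-parameter breakdown below is the commonly quoted accounting for mixed-precision AdamW training and is an assumption here, since the text does not spell it out; activation memory is not included.

```python
# Assumed per-parameter accounting for mixed-precision AdamW (activations excluded).
BYTES_PER_PARAM = {
    "weights (fp32 master + fp16 copy)": 4 + 2,
    "gradients (fp32)": 4,
    "AdamW moments (two fp32 values)": 4 + 4,
}


def full_finetune_memory_gb(n_params):
    """Approximate GPU memory in GB needed to update every parameter."""
    return n_params * sum(BYTES_PER_PARAM.values()) / 1e9


print(sum(BYTES_PER_PARAM.values()))           # 18
print(full_finetune_memory_gb(1_000_000_000))  # 18.0
```

Video inputs add long sequences of activations on top of this, which is one reason a large GPU is still recommended.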

Some clever people came up with the LoRA method (short for low-rank adaptation). It lets you freeze the existing weights and train only a couple of adapter layers on top of the base model. Hugging Face offers the separate [PEFT library](https://huggingface.co/docs/peft/main/en/index) for easy use of LoRA, along with other Parameter-Efficient Fine-Tuning methods.
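To get a feel for how few parameters the adapters add: for a linear layer with weight shape `d_out x d_in`, a rank-`r` LoRA adapter trains two small matrices of shapes `r x d_in` and `d_out x r`. The layer size and rank below are hypothetical values chosen only for the arithmetic.

```python
def lora_param_count(d_in, d_out, rank):
    """Trainable parameters of a rank-`rank` LoRA adapter on a d_out x d_in linear layer."""
    return rank * d_in + d_out * rank


d_in, d_out, rank = 1024, 1024, 8  # hypothetical layer size and rank
full = d_in * d_out
lora = lora_param_count(d_in, d_out, rank)
print(lora, full)   # 16384 1048576
print(full / lora)  # 64.0
```

In this example the adapter trains 64 times fewer parameters than the frozen layer it is attached to.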

Moreover, one can not only freeze the existing base model but also quantize it (that is, shrink its size). A neural network's parameters are typically stored in either float32 (32 bits, or 4 bytes, per parameter value) or float16 (16 bits, or 2 bytes, also called half precision). However, with some clever algorithms one can shrink each parameter to just 8 or 4 bits (half a byte!) without a significant effect on final performance. Read all about it here: https://huggingface.co/blog/4bit-transformers-bitsandbytes.
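The effect on weight storage is simple arithmetic. The sketch below computes the size of the weights alone for a 1-billion-parameter model at each precision; it ignores the small overhead of quantization constants, and the helper name `weight_storage_gb` is hypothetical.

```python
BITS_PER_PARAM = {"float32": 32, "float16": 16, "int8": 8, "4-bit": 4}


def weight_storage_gb(n_params, dtype):
    """Size of the weights alone, in GB, at the given precision."""
    return n_params * BITS_PER_PARAM[dtype] / 8 / 1e9


for dtype in BITS_PER_PARAM:
    print(dtype, weight_storage_gb(1_000_000_000, dtype))
# float32 4.0
# float16 2.0
# int8 1.0
# 4-bit 0.5
```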

This means that we're going to shrink the size of the base 1B model considerably using 4-bit quantization, and then only train a couple of adapter layers on top using LoRa (in float16). This idea of combining LoRa with quantization is called Q-LoRa and is the most memory friendly version.

There exist many forms of quantization; here we leverage the [BitsAndBytes integration](https://huggingface.co/docs/transformers/main_classes/quantization#transformers.BitsAndBytesConfig).
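As a configuration sketch, a QLoRA setup typically pairs a `BitsAndBytesConfig` with a `LoraConfig`. The values below (NF4, rank 8, alpha 16, dropout 0.05) are illustrative assumptions rather than tuned settings for this task, and the target module list simply mirrors the typical linear layer names of a Qwen-style language model.

```python
import torch
from peft import LoraConfig
from transformers import BitsAndBytesConfig

# 4-bit quantization settings for the frozen base model (QLoRA-style).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",             # NormalFloat4 data type
    bnb_4bit_use_double_quant=True,        # also quantize the quantization constants
    bnb_4bit_compute_dtype=torch.float16,  # dtype used for matmuls at runtime
)

# LoRA adapter settings; rank, alpha and dropout are illustrative, not tuned.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
)
```

The first object would be passed as `quantization_config` to `from_pretrained(...)`, and the second to `get_peft_model(...)` after `prepare_model_for_kbit_training(...)`.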

In [None]:
## Load model
# Three options for training, from the lowest-precision to the highest-precision:
# QLoRA: the base model is loaded with 4-bit quantization, which reduces memory usage while maintaining performance.
# Standard LoRA: the base model is loaded in half precision and only LoRA adapters are trained.
# Full fine-tuning: no memory optimizations are applied. In that case Flash Attention is used to speed up training, if the hardware supports it.

# YOUR CODE HERE

In [None]:
def find_all_linear_names(model):
    # Only for LoRA or QLoRA

    cls = torch.nn.Linear
    lora_module_names = set()
    multimodal_keywords = ['multi_modal_projector', 'vision_model']
    for name, module in model.named_modules():
        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
            continue
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if 'lm_head' in lora_module_names: # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)

# If you selected LoRA or QLoRA, choose which modules to adapt
# YOUR CODE HERE

# Then create LoraConfig and run prepare_model_for_kbit_training(...)
# and finally: model = get_peft_model(model, ...)
# YOUR CODE HERE

['up_proj', 'down_proj', 'q_proj', 'o_proj', 'k_proj', 'gate_proj', 'v_proj']

## Define PyTorch Lightning Module for InternVL
To streamline the training and evaluation of the InternVL model, you can use [LightningModule](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html), which abstracts away much of the boilerplate code and provides a structured framework for model training. In this section, you need to define the InternVLModelPLModule, a custom PyTorch Lightning module that encapsulates the model, training loop, validation loop, and optimizer configuration.

### InternVLModelPLModule Class

The InternVLModelPLModule class inherits from LightningModule and includes methods for training, validation, and optimizer configuration. This setup ensures a clean and efficient training process.

Basically, PyTorch Lightning will take care of all device placements (.to(device)) for us, as well as the backward pass, putting the model in training mode, etc.

Notice the difference between a training step and an evaluation step:

- a training step only consists of a forward pass, in which we compute the cross-entropy loss between the model's next token predictions and the ground truth (in parallel for all tokens, this technique is known as "teacher forcing"). The backward pass is handled by PyTorch Lightning.
- an evaluation step consists of making the model autoregressively complete the prompt using the generate() method. After that, you compute an evaluation metric between the predicted sequences and the ground truth ones. This allows you to see how the model is improving over the course of training. The metric we use here is accuracy of answering the question.
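The accuracy metric from the evaluation step can be sketched in plain Python. This assumes answers are compared by exact match after light normalization; the helper names and the sample strings are hypothetical, and the right normalization depends on how your prompts and ground-truth answers are formatted.

```python
def normalize_answer(text: str) -> str:
    """Lowercase, trim whitespace and drop a trailing period."""
    return text.strip().lower().rstrip(".").strip()


def answer_accuracy(predictions, references) -> float:
    """Fraction of predictions that match their reference after normalization."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    if not predictions:
        return 0.0
    correct = sum(
        normalize_answer(p) == normalize_answer(r)
        for p, r in zip(predictions, references)
    )
    return correct / len(predictions)


# Made-up decoded outputs and ground-truth answers.
predictions = ["(A) Yes.", "(b) no", "(C) Maybe"]
references = ["(A) yes", "(B) No", "(D) Unknown"]
print(f"accuracy = {answer_accuracy(predictions, references):.2f}")
```

Inside `validation_step`, the predictions would come from decoding the output of `generate()`, and the resulting value would be logged so that callbacks can monitor it.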

Besides that, you define the optimizer to use (AdamW is a good default choice) and the data loaders, which use the collate functions defined above to batch together items of the PyTorch datasets. Do note that AdamW is a pretty heavy optimizer in terms of memory requirements, but as we're training with QLoRA we only need to store optimizer states for the adapter layers. For full fine-tuning, one could take a look at more memory-friendly optimizers such as 8-bit Adam.

In [None]:
class InternVLModelPLModule(L.LightningModule):
    def __init__(self, config, processor, model):
        super().__init__()
        # YOUR CODE HERE

    def training_step(self, batch, batch_idx):
        # YOUR CODE HERE

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        # YOUR CODE HERE

    def configure_optimizers(self):
        # YOUR CODE HERE

    def train_dataloader(self):
        return DataLoader(train_dataset, collate_fn=train_collate_fn,
                          batch_size=self.batch_size, shuffle=True, num_workers=4)

    def val_dataloader(self):
        return DataLoader(eval_dataset, collate_fn=eval_collate_fn,
                          batch_size=self.batch_size, shuffle=False, num_workers=4)

Then instantiate it (based on a config dictionary which defines all hyperparameters for training).

The batch size was determined based on the compute available.

Do note that one can play around with the hyperparameters; the config below uses reasonable defaults: 2 epochs, a learning rate of 1e-4, gradient accumulation over 8 batches, and mixed precision for training (more memory friendly). One could extend this with things like gradient checkpointing.

I recommend [this guide](https://huggingface.co/docs/transformers/v4.20.1/en/perf_train_gpu_one) which goes over all tips and tricks regarding maximizing fine-tuning performance on consumer hardware.

In [None]:
config = {"max_epochs": 2,
          # "val_check_interval": 0.2, # how many times we want to validate during an epoch
          "check_val_every_n_epoch": 1,
          "gradient_clip_val": 1.0,
          "accumulate_grad_batches": 8,
          "lr": 1e-4,
          "batch_size": 1,
          "num_nodes": 1,
          "warmup_steps": 50,
}
# Instantiate your module here
# YOUR CODE HERE
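Two quantities implied by this config are worth computing explicitly: the effective batch size per optimizer step and the learning rate during warmup. The sketch below assumes a linear warmup followed by a constant rate, which is one common choice and not something the config itself prescribes; `example_config` and both helper functions are illustrative names.

```python
example_config = {
    "batch_size": 1,
    "accumulate_grad_batches": 8,
    "num_nodes": 1,
    "lr": 1e-4,
    "warmup_steps": 50,
}


def effective_batch_size(cfg, devices_per_node=1):
    """Number of samples that contribute to each optimizer step."""
    return (cfg["batch_size"] * cfg["accumulate_grad_batches"]
            * cfg["num_nodes"] * devices_per_node)


def warmup_lr(step, cfg):
    """Linear warmup up to cfg['lr'], then constant (assumed schedule)."""
    factor = min(1.0, (step + 1) / cfg["warmup_steps"])
    return cfg["lr"] * factor


print("effective batch size:", effective_batch_size(example_config))
for step in (0, 24, 49, 200):
    print(f"step {step:3d}: lr = {warmup_lr(step, example_config):.2e}")
```

In a `LightningModule`, a schedule like this would usually be returned from `configure_optimizers` alongside the optimizer.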

## Define callbacks
Optionally, Lightning lets you define so-called [callbacks](https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html), which are arbitrary pieces of code that can be executed during training.

We recommend using Lightning's EarlyStopping callback, which automatically stops training once the monitored evaluation metric (accuracy in our case) has not improved for 3 consecutive validation runs.
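The patience logic behind early stopping can be sketched in a few lines. This is a simplified illustration of the idea, not Lightning's actual implementation (which also supports options such as a minimum improvement threshold); the function name and the accuracy values are made up.

```python
def early_stopping_epoch(scores, patience=3, mode="max"):
    """Return the index of the validation run at which training would stop, or None."""
    if mode not in ("max", "min"):
        raise ValueError("mode must be 'max' or 'min'")
    best = None
    wait = 0  # validation runs since the last improvement
    for epoch, score in enumerate(scores):
        if best is None:
            improved = True
        elif mode == "max":
            improved = score > best
        else:
            improved = score < best
        if improved:
            best, wait = score, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


val_accuracy = [0.40, 0.50, 0.50, 0.49, 0.48, 0.60]  # hypothetical values
print(early_stopping_epoch(val_accuracy, patience=3))
```

Note that in this example training stops before the late improvement to 0.60 is ever seen, which is the usual trade-off when choosing a small patience.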

In [None]:
from huggingface_hub import HfApi
from lightning.pytorch.callbacks import Callback, EarlyStopping
from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger

logger = TensorBoardLogger("tb_logs", name="VideoLLava-demo")

api = HfApi()

class PushToHubCallback(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        print(f"Pushing model to the hub, epoch {trainer.current_epoch}")
        pl_module.model.push_to_hub(REPO_ID,
                                    commit_message=f"Training in progress, epoch {trainer.current_epoch}")

    def on_train_end(self, trainer, pl_module):
        print(f"Pushing model to the hub after training")
        pl_module.processor.push_to_hub(REPO_ID,
                                    commit_message=f"Training done")
        pl_module.model.push_to_hub(REPO_ID,
                                    commit_message=f"Training done")

early_stop_callback = EarlyStopping(monitor="your-metric-here",
                                    patience=3, verbose=False, mode="max")

## Train!
The Trainer class supports many more flags than you will need here; see the [docs](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.trainer.trainer.Trainer.html#lightning.pytorch.trainer.trainer.Trainer).

In [None]:
trainer = L.Trainer(
        # YOUR CODE HERE
)

In [None]:
trainer.fit(model_module)

In [None]:
# %load_ext tensorboard
# %reload_ext tensorboard
# %tensorboard --logdir .

## Inference

Let's see if the model has learned something. First, load the model from the hub.

In [None]:
from transformers import AutoProcessor, BitsAndBytesConfig, VideoLlavaForConditionalGeneration
import torch

# YOUR CODE HERE

Take one example from the validation set and plot its first 2 frames to see what is happening in the video.

In [None]:
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image

prompt, clip = eval_dataset[2]
fig, axarr = plt.subplots(1, 2, figsize = (10, 10))
fig.tight_layout()

for i in range(2):
    curr_frame = Image.fromarray(np.uint8(clip[i]))
    axarr[i].imshow(curr_frame)
    axarr[i].get_xaxis().set_visible(False)
    axarr[i].get_yaxis().set_visible(False)
    axarr[i].set_aspect('equal')

plt.subplots_adjust(wspace=None, hspace=None)
plt.axis('off')
plt.show()

Next you need to prepare the video for the model, along with the text prompt we used during training. You need to apply the chat template to make sure the format is respected.

Notice that this is exactly the same as what you did for the evaluation data collate function.

In [None]:
# YOUR CODE HERE

Next you should let the model autoregressively generate tokens using the [generate()](https://huggingface.co/docs/transformers/v4.40.1/en/main_classes/text_generation#transformers.GenerationMixin.generate) method, which is recommended for use at inference time. This method feeds each predicted token back into the model as conditioning for each next time step.

Do note that there are various ways of decoding text; here we use greedy decoding, which is the default. There are also fancier methods such as beam search and top-k sampling. Refer to [this amazing blog post](https://huggingface.co/blog/how-to-generate) for all details.
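Greedy decoding itself is easy to illustrate with a toy example. The "model" below is a hand-written score table that looks only at the last token, which is a deliberate simplification: a real language model conditions on the whole prefix, and `generate()` handles batching, caching and stopping criteria for you. All names and scores are invented for this sketch.

```python
# Toy next-token scores keyed by the previous token (hypothetical values).
NEXT_TOKEN_SCORES = {
    "<bos>": {"the": 0.6, "a": 0.4},
    "the": {"cat": 0.5, "dog": 0.3, "<eos>": 0.2},
    "a": {"dog": 0.7, "cat": 0.3},
    "cat": {"sits": 0.55, "<eos>": 0.45},
    "dog": {"runs": 0.8, "<eos>": 0.2},
    "sits": {"<eos>": 1.0},
    "runs": {"<eos>": 1.0},
}


def greedy_decode(start="<bos>", max_new_tokens=10):
    """Repeatedly pick the highest-scoring next token and feed it back in."""
    tokens = [start]
    for _ in range(max_new_tokens):
        scores = NEXT_TOKEN_SCORES[tokens[-1]]
        next_token = max(scores, key=scores.get)
        tokens.append(next_token)  # the prediction becomes the next input
        if next_token == "<eos>":
            break
    return tokens[1:]  # drop the start token


print(greedy_decode())
```

Beam search would instead keep several candidate sequences alive at each step, and sampling methods would draw the next token at random in proportion to its score.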

In [None]:
# Generate token IDs
# YOUR CODE HERE

# Decode back into text
# YOUR CODE HERE