## Evaluation on TVBench
This notebook evaluates VideoChat2 on the TVBench dataset. To run this code, first install the VideoChat2 [dependencies](https://github.com/OpenGVLab/Ask-Anything/blob/main/video_chat2/requirements.txt). TVBench follows the same structure as MVBench, so any codebase that supports MVBench can be adapted to TVBench by simply updating the dataset path.


In [None]:
import json
import torch

from tqdm import tqdm

from vqa_dataset import VQADataset, check_ans
from vqa_model import build_videoChat2_infer, get_sinusoid_encoding_table, ask, answer
from utils.easydict import EasyDict
from utils.config import Config

## TVBench dataset paths
Download the TVBench dataset from [here](https://huggingface.co/datasets/FunAILab/TVBench).

Replace `/datasets/TVBench/` with the path to the TVBench folder on your system.

In [2]:
# Root folder of the downloaded TVBench dataset. Changing this single value
# relocates both the annotation JSONs and the video folders below.
TVBENCH_ROOT = "/datasets/TVBench"

data_dir = f"{TVBENCH_ROOT}/json"

# (task name, file/folder stem, whether the annotations have start & end times)
_TVBENCH_TASKS = [
    ("Action Count", "action_count", False),
    ("Object Count", "object_count", False),
    ("Action Sequence", "action_sequence", True),  # has start & end
    ("Object Shuffle", "object_shuffle", False),
    ("Scene Transition", "scene_transition", False),
    ("Action Localization", "action_localization", True),  # has start & end
    ("Action Antonym", "action_antonym", False),
    ("Unexpected Action", "unexpected_action", False),
    ("Egocentric Sequence", "egocentric_sequence", False),
    ("Moving Direction", "moving_direction", False),
]

# task name -> (annotation file in data_dir, video folder, media type, has start & end)
data_list = {
    name: (f"{stem}.json", f"{TVBENCH_ROOT}/video/{stem}", "video", has_start_end)
    for name, stem, has_start_end in _TVBENCH_TASKS
}

## Create a VQA dataset

Create a VQADataset that will load the TVBench dataset defined above.

In [3]:
# Number of frames (segments) sampled per video and the frame resolution.
# These must agree with the values used to resize the vision encoder's
# positional embedding in the "Init VideoChat2" cell.
num_frame = 16
resolution = 224

dataset = VQADataset(
    data_dir,
    data_list,
    num_segments=num_frame,
    resolution=resolution,
)

## Init VideoChat2
Download [VideoChat Stage 3](https://github.com/OpenGVLab/Ask-Anything/blob/main/video_chat2/scripts/videochat_vicuna/config_7b_stage3.py) and update the checkpoint path passed to `build_videoChat2_infer` below. Also update the model paths in `configs/config.json` (the config file loaded below) to point to the following files:

- `"videochat2_model_path"`: [VideoChat2](https://huggingface.co/OpenGVLab/videochat/resolve/main/videochat2_7b_stage2.pth)
- `"vit_blip_model_path"`: [VideoBLIP](https://huggingface.co/OpenGVLab/videochat/resolve/main/umt_l16_qformer.pth)
- `"llama_model_path"`: [LLM](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat#running-usage)

In [None]:
config_file = "configs/config.json"
cfg = Config.from_file(config_file)

# Stage-3 VideoChat2 checkpoint; update this path for your system.
model_path = "/models/video_chat2/videochat2_7b_stage3.pth"
model = build_videoChat2_infer(model_path, cfg)

# Rebuild the vision encoder's positional embedding for the clip length and
# resolution used by the dataset cell. `num_frame` and `resolution` are reused
# from that cell rather than redefined here, so the embedding can never drift
# out of sync with the frames the dataset actually samples.
# (resolution // 16) ** 2 positions per frame -- presumably a 16x16 patch size; TODO confirm.
new_pos_emb = get_sinusoid_encoding_table(
    n_position=(resolution // 16) ** 2 * num_frame,
    cur_frame=num_frame,
)
model.vision_encoder.encoder.pos_embed = new_pos_emb

## VideoChat2 inference function
We keep the same inference function as in the original implementation, as methods evaluated on MVBench can be directly evaluated on TVBench following the same strategy.

In [5]:
def infer(
        data_sample, system="",
        question_prompt='',  # appended to the end of the question
        answer_prompt=None,  # added at the beginning of the answer
        return_prompt='',  # prepended to the returned message
        system_q=False,  # whether to add the question to the system prompt for the QFormer
        print_res=True,
        system_llm=False
    ):
    """Run VideoChat2 on one multiple-choice sample and return its answer line.

    Relies on the module-level ``model`` built in the "Init VideoChat2" cell.

    Args:
        data_sample: mapping with ``"video"``, ``"question"`` and (when
            ``print_res`` is true) ``"answer"``. ``"video"`` has shape
            ``(T*3, H, W)`` and is reshaped below to ``(1, T, 3, H, W)``.
        system: system prompt text.
        question_prompt: text appended after the question.
        answer_prompt: text added at the beginning of the answer; forwarded
            unchanged to ``answer``.
        return_prompt: text prepended to the returned message.
        system_q: if True, the visual encoder receives ``system + question``
            instead of ``system`` alone.
        print_res: if True, print the prediction and ground truth; also
            forwarded to ``answer``.
        system_llm: if True, ``system`` is prepended to the prompt sent to the LLM.

    Returns:
        ``return_prompt`` followed by the first line of the stripped model reply.
    """
    video = data_sample["video"]
    TC, H, W = video.shape
    # Split the stacked frame*channel axis into (frames, RGB) and add a batch dim.
    video = video.reshape(1, TC // 3, 3, H, W).to("cuda:0")

    with torch.no_grad():
        encoder_prompt = system + data_sample['question'] if system_q else system
        video_emb, _ = model.encode_img(video, encoder_prompt)
    video_list = [video_emb]

    chat = EasyDict({
        "system": system,
        "roles": ("Human", "Assistant"),
        "messages": [],
        "sep": "###"
    })

    chat.messages.append([chat.roles[0], "<Video><VideoHere></Video>\n"])

    prompt = data_sample['question'] + question_prompt
    if system_llm:
        prompt = system + prompt

    ask(prompt, chat)

    llm_message = answer(
        conv=chat, model=model, do_sample=False,
        img_list=video_list, max_new_tokens=100,
        answer_prompt=answer_prompt, print_res=print_res
    )[0]
    # Keep only the first line to drop any explanation after the chosen option.
    llm_message = return_prompt + llm_message.strip().split('\n')[0]
    # Honour print_res: previously these lines printed even when it was False.
    if print_res:
        print(llm_message)
        print(f"GT: {data_sample['answer']}")
    return llm_message

## Inference loop
We also keep the same implementation of the inference loop.

In [None]:
save_path = "./test"

# Prompting recipe applied to every sample (unchanged from the MVBench setup).
infer_kwargs = dict(
    system="Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons. Based on your observations, select the best option that accurately addresses the question.\n",
    question_prompt="\nOnly give the best option.",
    answer_prompt="Best option:(",
    return_prompt='(',
    system_q=False,
    print_res=True,
    system_llm=True,
)

correct = 0
total = 0
res_list = []
acc_dict = {}  # task_type -> [num_correct, num_seen]

for example in tqdm(dataset):
    task_type = example['task_type']
    task_stats = acc_dict.setdefault(task_type, [0, 0])
    task_stats[1] += 1
    total += 1

    pred = infer(example, **infer_kwargs)
    gt = example['answer']
    res_list.append({'pred': pred, 'gt': gt})

    if check_ans(pred=pred, gt=gt):
        task_stats[0] += 1
        correct += 1

    print(f"Part  Acc: {task_stats[0] / task_stats[1] * 100 :.2f}%")
    print(f"Total Acc: {correct / total * 100 :.2f}%")
    print('-' * 30, task_type, '-' * 30)

with open(f"{save_path}.json", "w") as out_file:
    json.dump({"acc_dict": acc_dict, "res_list": res_list}, out_file)