In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
import os
initial_directory = os.getcwd()
print("Initial Directory:", initial_directory)
os.chdir('../')

Initial Directory: /root/code/LLM_pipline/video_chat2/demo


In [2]:
from utils.config import Config
config_file = "configs/config_mistral_hd.json"
cfg = Config.from_file(config_file)

In [3]:
import io

from models import VideoChat2_it_hd_mistral
from utils.easydict import EasyDict
import torch

from transformers import StoppingCriteria, StoppingCriteriaList

from PIL import Image
import numpy as np
import numpy as np
from decord import VideoReader, cpu
import torchvision.transforms as T
from torchvision.transforms import PILToTensor
from torchvision import transforms
from dataset.video_transforms import (
    GroupNormalize, GroupScale, GroupCenterCrop, 
    Stack, ToTorchFormatTensor
)
from torch.utils.data import Dataset
from torchvision.transforms.functional import InterpolationMode

from torchvision import transforms

import matplotlib.pyplot as plt

from IPython.display import Video, HTML

from peft import get_peft_model, LoraConfig, TaskType
import copy

import json
from collections import OrderedDict

from tqdm import tqdm

import decord
import time
decord.bridge.set_bridge("torch")

In [4]:
# load stage2 model
cfg.model.vision_encoder.num_frames = 16
model = VideoChat2_it_hd_mistral(config=cfg.model)

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
model

VideoChat2_it_hd_mistral(
  (vision_encoder): PretrainVisionTransformer(
    (encoder): PretrainVisionTransformerEncoder(
      (patch_embed): PatchEmbed(
        (proj): Conv3d(3, 1024, kernel_size=(1, 16, 16), stride=(1, 16, 16))
      )
      (blocks): ModuleList(
        (0): Block(
          (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=1024, out_features=3072, bias=False)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=1024, out_features=1024, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=1024, out_features=4096, bias=True)
            (act): GELU(approximate='none')
            (fc2): Linear(in_features=4096, out_features=1024, bias=True)
            (dr

In [5]:
def get_model_size(model):
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
    size_all_mb = (param_size + buffer_size) / (1024 ** 2)
    return size_all_mb

# Assuming lora_model is your model
model_size = get_model_size(model)
print(f"模型大小: {model_size:.2f} MB")

模型大小: 16139.64 MB


In [6]:
# add lora to run stage3 model
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, inference_mode=False, 
    r=16, lora_alpha=32, lora_dropout=0.,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
         "gate_proj", "up_proj", "down_proj", "lm_head"
    ]
)
model.mistral_model = get_peft_model(model.mistral_model, peft_config)

In [7]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
state_dict = torch.load("/vepfs/fs_users/lkn/videochat2/videochat2_hd_mistral_7b_stage4.pth", map_location=device)

if 'model' in state_dict.keys():
    msg = model.load_state_dict(state_dict['model'], strict=False)
else:
    msg = model.load_state_dict(state_dict, strict=False)
print(msg)

model = model.to(torch.device(cfg.device))
model = model.eval()

_IncompatibleKeys(missing_keys=['mistral_model.base_model.model.model.embed_tokens.weight', 'mistral_model.base_model.model.model.layers.0.self_attn.q_proj.weight', 'mistral_model.base_model.model.model.layers.0.self_attn.k_proj.weight', 'mistral_model.base_model.model.model.layers.0.self_attn.v_proj.weight', 'mistral_model.base_model.model.model.layers.0.self_attn.o_proj.weight', 'mistral_model.base_model.model.model.layers.0.mlp.gate_proj.weight', 'mistral_model.base_model.model.model.layers.0.mlp.up_proj.weight', 'mistral_model.base_model.model.model.layers.0.mlp.down_proj.weight', 'mistral_model.base_model.model.model.layers.0.input_layernorm.weight', 'mistral_model.base_model.model.model.layers.0.post_attention_layernorm.weight', 'mistral_model.base_model.model.model.layers.1.self_attn.q_proj.weight', 'mistral_model.base_model.model.model.layers.1.self_attn.k_proj.weight', 'mistral_model.base_model.model.model.layers.1.self_attn.v_proj.weight', 'mistral_model.base_model.model.mode

In [8]:
def get_prompt(conv):
    ret = conv.system + conv.sep
    for role, message in conv.messages:
        if message:
            ret += role + " " + message + " " + conv.sep
        else:
            ret += role
    return ret


def get_prompt2(conv):
    ret = conv.system + conv.sep
    count = 0
    for role, message in conv.messages:
        count += 1
        if count == len(conv.messages):
            ret += role + " " + message
        else:
            if message:
                ret += role + " " + message + " " + conv.sep
            else:
                ret += role
    return ret


def get_context_emb(conv, model, img_list, answer_prompt=None, print_res=False):
    if answer_prompt:
        prompt = get_prompt2(conv)
    else:
        prompt = get_prompt(conv)
    if print_res:
        print(prompt)
    if '<VideoHere>' in prompt:
        prompt_segs = prompt.split('<VideoHere>')
    else:
        prompt_segs = prompt.split('<ImageHere>')
    assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
    with torch.no_grad():
        seg_tokens = [
            model.mistral_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to("cuda:0").input_ids
            # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [model.mistral_model.base_model.model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
#         seg_embs = [model.mistral_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
    mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
    mixed_embs = torch.cat(mixed_embs, dim=1)
    return mixed_embs


def ask(text, conv):
    conv.messages.append([conv.roles[0], text])
        

class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = stops
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True
        return False
    
    
def answer(conv, model, img_list, do_sample=True, max_new_tokens=200, num_beams=1, min_length=1, top_p=0.9,
               repetition_penalty=1.0, length_penalty=1, temperature=1.0, answer_prompt=None, print_res=False):
    stop_words_ids = [
        torch.tensor([2]).to("cuda:0"),
        torch.tensor([29871, 2]).to("cuda:0")]  # '</s>' can be encoded in two different ways.
    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
    
    conv.messages.append([conv.roles[1], answer_prompt])
    embs = get_context_emb(conv, model, img_list, answer_prompt=answer_prompt, print_res=print_res)
    with torch.no_grad():
        outputs = model.mistral_model.generate(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            stopping_criteria=stopping_criteria,
            num_beams=num_beams,
            do_sample=do_sample,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
        )
    output_token = outputs[0]
    if output_token[0] == 0:  # the model might output a unknow token <unk> at the beginning. remove it
            output_token = output_token[1:]
    if output_token[0] == 1:  # some users find that there is a start token <s> at the beginning. remove it
            output_token = output_token[1:]
    output_text = model.mistral_tokenizer.decode(output_token, add_special_tokens=False)
    output_text = output_text.split('</s>')[0]  # remove the stop sign </s>
#     output_text = output_text.split('[/INST]')[-1].strip()
    conv.messages[-1][1] = output_text + '</s>'
    return output_text, output_token.cpu().numpy()

In [9]:
from dataset.hd_utils import HD_transform_padding, HD_transform_no_padding

def get_index(num_frames, num_segments):
    seg_size = float(num_frames - 1) / num_segments
    start = int(seg_size / 2)
    offsets = np.array([
        start + int(np.round(seg_size * idx)) for idx in range(num_segments)
    ])
    return offsets

def load_video(video_path, num_segments=8, return_msg=False, resolution=224, hd_num=6, padding=False):
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    num_frames = len(vr)
    frame_indices = get_index(num_frames, num_segments)

    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    transform = transforms.Compose([
        transforms.Lambda(lambda x: x.float().div(255.0)),
        transforms.Normalize(mean, std)
    ])

    frames = vr.get_batch(frame_indices)
    frames = frames.permute(0, 3, 1, 2)

    if padding:
        frames = HD_transform_padding(frames.float(), image_size=resolution, hd_num=hd_num)
    else:
        frames = HD_transform_no_padding(frames.float(), image_size=resolution, hd_num=hd_num)

    frames = transform(frames)
    print(frames.shape)
    
    if return_msg:
        fps = float(vr.get_avg_fps())
        sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
        # " " should be added in the start and end
        msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
        return frames, msg
    else:
        return frames
    

def load_video_by_segments(video_path, segment_duration=10, num_segments=8, return_msg=False, resolution=224, hd_num=6, padding=False):
    """先分割长视频，然后从每个分割后的片段中提取帧"""
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    num_frames = len(vr)
    fps = vr.get_avg_fps()
    total_duration = num_frames / fps
    
    # 计算总的分段数量
    num_total_segments = int(total_duration // segment_duration)
    
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    transform = transforms.Compose([
        transforms.Lambda(lambda x: x.float().div(255.0)),
        transforms.Normalize(mean, std)
    ])
    
    frames_list = []
    msg_list = []
    
    # 按时间段分割视频
    for i in range(num_total_segments):
        # print(i)
        start_time = i * segment_duration
        end_time = start_time + segment_duration
        start_frame = int(start_time * fps)
        end_frame = int(end_time * fps)
        
        # 确保索引不会超出视频的总帧数
        frame_indices = get_index(end_frame - start_frame, num_segments) + start_frame
        
        if len(frame_indices) > 0:
            frames = vr.get_batch(frame_indices)
            frames = frames.permute(0, 3, 1, 2)

            if padding:
                frames = HD_transform_padding(frames.float(), image_size=resolution, hd_num=hd_num)
            else:
                frames = HD_transform_no_padding(frames.float(), image_size=resolution, hd_num=hd_num)

            frames = transform(frames)
            frames_list.append(frames)
            
            if return_msg:
                sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
                msg = f"视频的第 {i + 1} 段包含 {len(frame_indices)} 帧，采样时间为 {sec} 秒。"
                msg_list.append(msg)
    
    if return_msg:
        return frames_list, msg_list
    else:
        return frames_list

In [11]:
def get_sinusoid_encoding_table(n_position=784, d_hid=1024, cur_frame=8, ckpt_num_frame=4, pre_n_position=784): 
    ''' Sinusoid position encoding table ''' 
    # TODO: make it with torch instead of numpy 
    def get_position_angle_vec(position): 
        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] 
    
    # generate checkpoint position embedding
    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(pre_n_position)]) 
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 
    sinusoid_table = torch.tensor(sinusoid_table, dtype=torch.float, requires_grad=False).unsqueeze(0)
    
    print(f"n_position: {n_position}")
    print(f"pre_n_position: {pre_n_position}")
    
    if n_position != pre_n_position:
        T = ckpt_num_frame # checkpoint frame
        P = 14 # checkpoint size
        C = d_hid
        new_P = int((n_position // cur_frame) ** 0.5) # testing size
        if new_P != 14:
            print(f'Pretraining uses 14x14, but current version is {new_P}x{new_P}')
            print(f'Interpolate the position embedding')
            sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C)
            sinusoid_table = sinusoid_table.reshape(-1, P, P, C).permute(0, 3, 1, 2)
            sinusoid_table = torch.nn.functional.interpolate(
                sinusoid_table, size=(new_P, new_P), mode='bicubic', align_corners=False)
            # BT, C, H, W -> BT, H, W, C ->  B, T, H, W, C
            sinusoid_table = sinusoid_table.permute(0, 2, 3, 1).reshape(-1, T, new_P, new_P, C)
            sinusoid_table = sinusoid_table.flatten(1, 3)  # B, THW, C
    
    if cur_frame != ckpt_num_frame:
        print(f'Pretraining uses 4 frames, but current frame is {cur_frame}')
        print(f'Interpolate the position embedding')
        T = ckpt_num_frame # checkpoint frame
        new_T = cur_frame # testing frame
        # interpolate
        P = int((n_position // cur_frame) ** 0.5) # testing size
        C = d_hid
        sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C)
        sinusoid_table = sinusoid_table.permute(0, 2, 3, 4, 1).reshape(-1, C, T)  # BHW, C, T
        sinusoid_table = torch.nn.functional.interpolate(sinusoid_table, size=new_T, mode='linear')
        sinusoid_table = sinusoid_table.reshape(1, P, P, C, new_T).permute(0, 4, 1, 2, 3) # B, T, H, W, C
        sinusoid_table = sinusoid_table.flatten(1, 3)  # B, THW, C
        
    return sinusoid_table

In [12]:
# vid_path = "/root/code/LLM_pipline/video_chat2/example/test.mp4"
# # vid_path = "./demo/example/jesse_dance.mp4"


# num_frame = 8
# # num_frame = 16
# # resolution = 384
# resolution = 224
# # hd_num = 6
# hd_num = 12
# padding = False
# vid, msg = load_video(
#     vid_path, num_segments=num_frame, return_msg=True, resolution=resolution,
#     hd_num=hd_num, padding=padding
# )
# new_pos_emb = get_sinusoid_encoding_table(n_position=(resolution//16)**2*num_frame, cur_frame=num_frame)
# model.vision_encoder.encoder.pos_embed = new_pos_emb

# print(msg)
    
# # The model expects inputs of shape: T x C x H x W
# T_, C, H, W = vid.shape
# video = vid.reshape(1, T_, C, H, W).to("cuda:0")

# img_list = []
# with torch.no_grad():
#     image_emb, _, _ = model.encode_img(video, "Watch the video and answer the question.")
# #     image_emb, _, _ = model.encode_img(video, "")

# img_list.append(image_emb[0])

# HTML(f'<video alt="test" controls><source src="{vid_path}" type="video/mp4"></video>')

torch.Size([8, 3, 672, 896])
n_position: 1568
pre_n_position: 784
Pretraining uses 4 frames, but current frame is 8
Interpolate the position embedding
The video contains 8 frames sampled at 0.4, 1.3, 2.3, 3.2, 4.1, 5.0, 5.9, 6.8 seconds.


In [19]:
vid_path = "/root/code/LLM_pipline/video_chat2/example/trump_long.mp4"
# vid_path = "./demo/example/jesse_dance.mp4"

num_frame = 10
resolution = 224
hd_num = 12
padding = False
segment_duration = 50  # 设定固定的时间段长度（单位为秒）

# 使用你自己的逻辑，先按固定时间段分割视频，再提取帧
vid_segments, msg_segments = load_video_by_segments(
    vid_path, segment_duration=segment_duration, num_segments=num_frame, return_msg=True,
    resolution=resolution, hd_num=hd_num, padding=padding
)
print("length of segments is: ",len(msg_segments))
# 将所有分段的信息进行汇总输出
# for i, msg in enumerate(msg_segments):
#     print(f"Segment {i+1}: {msg}")

# 生成位置编码表
new_pos_emb = get_sinusoid_encoding_table(n_position=(resolution // 16) ** 2 * num_frame, cur_frame=num_frame)

# 设置模型中的位置编码
model.vision_encoder.encoder.pos_embed = new_pos_emb

# 初始化列表来存储每段的image embeddings
# img_list = []
out_list = []
# 遍历所有的分段视频帧

for vid in tqdm(vid_segments, desc="Processing video segments"):
    
    # 确保视频数据形状为 T x C x H x W
    T_, C, H, W = vid.shape
    video = vid.reshape(1, T_, C, H, W).to("cuda:0") 
    # print("after",video.device)
    # 使用模型编码每段视频帧
    with torch.no_grad():
        image_emb, _, _ = model.encode_img(video, "Watch the video and answer the question.")
    
    # 将生成的embedding添加到列表中
    # img_list.append(image_emb[0])
    torch.cuda.empty_cache()
    chat = EasyDict({
    "system": "",
    "roles": ("[INST]", "[/INST]"),
    "messages": [],
    "sep": ""
    })
    itr_list = []
    chat.messages.append([chat.roles[0], "<Video><VideoHere></Video> [/INST] Summarize the main story or event of the video. What is happening?"])
    llm_message_1 = answer(conv=chat, model=model, do_sample=False, img_list=[image_emb[0]], max_new_tokens=1024, print_res=True)[0]
    itr_list.append(llm_message_1)
    
    chat.messages.append([chat.roles[0], "Now describe the main characters in the video. Include details about their clothing or physical appearance."])
    llm_message_2 = answer(conv=chat, model=model, do_sample=False, img_list=[image_emb[0]], max_new_tokens=1024, print_res=True)[0]
    itr_list.append(llm_message_2)
    
    chat.messages.append([chat.roles[0], "Describe the objects and background elements, including the environment or any significant objects in the video."])
    llm_message_3 = answer(conv=chat, model=model, do_sample=False, img_list=[image_emb[0]], max_new_tokens=1024, print_res=True)[0]
    itr_list.append(llm_message_3)

    out_list.append(itr_list)

    ###  model.asnwer......
    ##  清内存一下
out_list
torch.cuda.empty_cache()

# 显示视频
# HTML(f'<video alt="test" controls><source src="{vid_path}" type="video/mp4"></video>')


length of segments is:  3
n_position: 2352
pre_n_position: 784
Pretraining uses 4 frames, but current frame is 12
Interpolate the position embedding


Processing video segments:   0%|          | 0/3 [00:00<?, ?it/s]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


[INST] <Video><VideoHere></Video> [/INST] Summarize the main story or event of the video. What is happening? [/INST]


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


[INST] <Video><VideoHere></Video> [/INST] Summarize the main story or event of the video. What is happening? [/INST] The video captures a political rally where a speaker, presumably a politician, is addressing a crowd. The speaker is dressed in a dark suit and a red cap with the word "MAGA" printed on it. The crowd is diverse, with individuals wearing various hats, some with the same "MAGA" slogan, and others with different slogans. The speaker is gesturing with his hands and appears to be speaking passionately. The background is filled with a sea of people, some holding up signs, and the atmosphere is lively and energetic. </s> [INST] Now describe the main characters in the video. Include details about their clothing or physical appearance. [/INST]


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


[INST] <Video><VideoHere></Video> [/INST] Summarize the main story or event of the video. What is happening? [/INST] The video captures a political rally where a speaker, presumably a politician, is addressing a crowd. The speaker is dressed in a dark suit and a red cap with the word "MAGA" printed on it. The crowd is diverse, with individuals wearing various hats, some with the same "MAGA" slogan, and others with different slogans. The speaker is gesturing with his hands and appears to be speaking passionately. The background is filled with a sea of people, some holding up signs, and the atmosphere is lively and energetic. </s> [INST] Now describe the main characters in the video. Include details about their clothing or physical appearance. [/INST] The main character in the video is the speaker, who is wearing a dark suit and a red cap with the word "MAGA" printed on it. The crowd consists of individuals wearing various hats, some with the same "MAGA" slogan, and others with different

Processing video segments:  33%|███▎      | 1/3 [00:10<00:21, 10.60s/it]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


[INST] <Video><VideoHere></Video> [/INST] Summarize the main story or event of the video. What is happening? [/INST]


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


[INST] <Video><VideoHere></Video> [/INST] Summarize the main story or event of the video. What is happening? [/INST] The video captures a political rally with a crowd of people, some holding signs, and a stage with a podium where a speaker is addressing the audience. </s> [INST] Now describe the main characters in the video. Include details about their clothing or physical appearance. [/INST]


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


[INST] <Video><VideoHere></Video> [/INST] Summarize the main story or event of the video. What is happening? [/INST] The video captures a political rally with a crowd of people, some holding signs, and a stage with a podium where a speaker is addressing the audience. </s> [INST] Now describe the main characters in the video. Include details about their clothing or physical appearance. [/INST] The main characters in the video are the speaker and the audience. The speaker is dressed in a white shirt and a red cap, while the audience is wearing a variety of clothing, including red caps and shirts. </s> [INST] Describe the objects and background elements, including the environment or any significant objects in the video. [/INST]


Processing video segments:  67%|██████▋   | 2/3 [00:15<00:07,  7.89s/it]


OutOfMemoryError: CUDA out of memory. Tried to allocate 2.31 GiB (GPU 0; 23.65 GiB total capacity; 18.84 GiB already allocated; 1.53 GiB free; 21.31 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [17]:
out_list

[['The video shows a person, presumably a politician, addressing a crowd. The individual is wearing a red cap with the word "TRUMP" on it and is holding a microphone, suggesting they are delivering a speech. The crowd is composed of people wearing red caps and holding signs, indicating a political rally or event. ',
  'The main character in the video is a person wearing a red cap with the word "TRUMP" on it. They are also wearing a dark suit and are holding a microphone, suggesting they are delivering a speech. The crowd in the background is composed of people wearing red caps and holding signs, indicating a political rally or event. ',
  'The background of the video is a crowd of people, many of whom are wearing red caps and holding signs. The environment suggests an outdoor setting, possibly a rally or event. The person in the foreground is holding a microphone, indicating they are addressing the crowd. '],
 ['The video shows a person, presumably a politician, addressing a crowd. The

In [13]:
chat = EasyDict({
    "system": "",
    "roles": ("[INST]", "[/INST]"),
    "messages": [],
    "sep": ""
})

chat.messages.append([chat.roles[0], "<Video><VideoHere></Video> [/INST]"])
# chat.messages.append([chat.roles[0], f"<Video><VideoHere></Video> {msg} [/INST]"])
ask("Describe the video in the format:1.Introduction, 2.character, 3.objects", chat)
#必要的生成，背景信息
llm_message = answer(conv=chat, model=model, do_sample=False, img_list=img_list, max_new_tokens=512, print_res=True)[0]
print(llm_message)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


[INST] <Video><VideoHere></Video> [/INST] [INST] Describe the video in the format:1.Introduction, 2.character, 3.objects [/INST]
1. Introduction: The video opens with a blurry shot of a street at night, with rain pouring down. 2. Character: A man dressed in a dark suit and tie is seen holding an umbrella. 3. Objects: The man is holding a black umbrella and is standing in front of a storefront with the words "First Edition" visible. 


In [None]:
# img_path = "./demo/example/run.jpg"
img_path = "./demo/example/dog.png"
img = Image.open(img_path).convert('RGB')

plt.imshow(img)

resolution = 224
# resolution = 384
# hd_num = 6
hd_num = 12
padding = False

new_pos_emb = get_sinusoid_encoding_table(n_position=(resolution//16)**2, cur_frame=1, ckpt_num_frame=1, pre_n_position=14*14)
model.vision_encoder.encoder.img_pos_embed = new_pos_emb

mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

transform = transforms.Compose([
    transforms.Lambda(lambda x: x.float().div(255.0)),
    transforms.Normalize(mean, std)
])
img = PILToTensor()(img).unsqueeze(0)

if padding:
    img = HD_transform_padding(img.float(), image_size=resolution, hd_num=hd_num)
else:
    img = HD_transform_no_padding(img.float(), image_size=resolution, hd_num=hd_num)
    
img = transform(img).unsqueeze(0).cuda()
print(img.shape)

img_list = []
with torch.no_grad():
#     image_emb, _, _ = model.encode_img(img, "")
    image_emb, _, _ = model.encode_img(img, "Observe the image and answer the question.")
img_list.append(image_emb[0])

In [None]:
chat = EasyDict({
    "system": "",
    "roles": ("[INST]", "[/INST]"),
    "messages": [],
    "sep": ""
})

chat.messages.append([chat.roles[0], f"<Image><ImageHere></Image> [/INST]"])
ask("Describe the following image in details.", chat)

llm_message = answer(conv=chat, model=model, do_sample=False, img_list=img_list, max_new_tokens=256, print_res=True,)[0]
print(llm_message)

In [None]:
def check_answer_egoschema(pred, qid):
    correct = 0
    answer_content = ans_dict[qid]['content'].lower()
    if answer_content[-1] == ".":
        answer_content = answer_content[:-1]
    if ans_dict[qid]['answer'].lower() in pred.lower():
        flag = True
        for kk in ["(A)", "(B)", "(C)", "(D)", "(E)"]:
            if kk != ans_dict[qid]['answer'].lower() and kk in pred.lower():
                flag = ans_dict
                break
        if flag:
            correct += 1
    elif answer_content in pred.lower():
        correct = 1
    elif answer_content.replace("a ", "") in pred.lower():
        correct = 1
    elif answer_content.replace("an ", "") in pred.lower():
        correct = 1
    return correct

def infer_egoschema(
        data_sample, system="", 
        question_prompt='', # add in the end of question
        answer_prompt=None, # add in the begining of answer
        return_prompt='',  # add in the begining of return message
        system_q=False, # whether add question in the system prompt for QFormer
        print_res=True,
        system_llm=False,
        num_segments=8,
    ):
    vid_path = os.path.join("your_data_path/egoschema/videos", data_sample['video'])
    video, _ = load_video(vid_path, num_segments=num_segments, return_msg=True)
    T_, C, H, W = video.shape
    video = video.reshape(1, T_, C, H, W).to("cuda:0")
    
    video_list = []
    with torch.no_grad():
        if system_q:
            video_emb, _, _ = model.encode_img(video, system + data_sample['question'])
        else:
            video_emb, _, _ = model.encode_img(video, system)
    video_list.append(video_emb[0])

    chat = EasyDict({
        "system": system,
        "roles": ("[INST]", "[/INST]"),
        "messages": [],
        "sep": ""
    })

    chat.messages.append([chat.roles[0], f"<Video><VideoHere></Video> [/INST]"])
    
    if system_llm:
        prompt = system + data_sample['QA'][0]['q'] + question_prompt
    else:
        prompt = data_sample['QA'][0]['q'] + question_prompt
    
    ask(prompt, chat)

    llm_message = answer(
        conv=chat, model=model, do_sample=False, 
        img_list=video_list, max_new_tokens=100, 
        answer_prompt=answer_prompt, print_res=print_res
    )[0]
    # remove potential explanation
    llm_message = return_prompt + llm_message.strip().split('\n')[0]
    print(llm_message)
    print(f"GT: {data_sample['QA'][0]['a']}")
    return llm_message


import csv
# You can find the csv files in https://github.com/imagegridworth/IG-VLM/blob/main/data/multiple_choice_qa/EgoSchema.csv
with open("your_data_path/EgoSchema.csv", mode='r', encoding='utf-8') as file:
    reader = csv.reader(file)

    json_data = []
    ans_dict = {}
    
    for idx, msg in enumerate(reader):
        if idx == 0:
            print(msg)
            continue
            
        video = msg[1] + '.mp4'
        input_str = f"Question: {msg[3].capitalize()}\nOptions:\n"
    
        target_index = -1
        for i, candidate in enumerate(msg[5:]):
            option = chr(ord('A') + i)
            input_str += f"({option}) {candidate}\n"
            if candidate == msg[4]:
                target_index = i
            
        assert target_index != -1
        correct = chr(ord('A') + target_index)
        
        json_data.append({
            'video': video,
            "QA": [{
                "i": "",
                "q": input_str.strip(),
                "a": f"Answer: ({correct}) {msg[4]}",
            }]
        })

        ans_dict[idx - 1] = {
            'video': video,
            'answer': f"({correct})",
            'content': msg[4],
        }


#  position embedding
num_frame = 16
resolution = 224
new_pos_emb = get_sinusoid_encoding_table(n_position=(resolution//16)**2*num_frame, cur_frame=num_frame)
model.vision_encoder.encoder.pos_embed = new_pos_emb

correct = 0
total = 0
total_num = len(json_data)

output = ""

for idx, example in enumerate(tqdm(json_data)):
    start = time.time()
    llm_message = infer_egoschema(
        example, 
        "Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons. Based on your observations, select the best option that accurately addresses the question.\n", 
        question_prompt="\nOnly give the best option.", 
        answer_prompt="Best option:(",
        return_prompt='(',
        system_q=False,
        print_res=False,
        system_llm=False,
        num_segments=16
    )
    
    duration = time.time() - start
    output += (example["video"] + '\n')
    output += (llm_message + '\n')
    correct += check_answer_egoschema(llm_message, idx)
    total += 1
    print("Acc:", correct / total)
    print('-' * 20, f'{idx+1}/{total_num} done,', f'cost: {duration:.2f}s', '-' * 20)

with open("./demo/egoschema/your_prediction.txt", "w") as f:
    f.writelines(output)

In [None]:
# You can find the csv files in https://github.com/egoschema/EgoSchema/blob/main/questions.json
with open("your_data_path/EgoSchema/questions.json", "r") as f:
    full_data = json.load(f)

full_egoschema = []
for data in full_data:
    video = data['q_uid'] + '.mp4'
    input_str = f"Question: {data['question'].capitalize()}\nOptions:\n"

    for i, candidate in enumerate(['option 0', 'option 1', 'option 2', 'option 3', 'option 4']):
        option = chr(ord('A') + i)
        input_str += f"({option}) {data[candidate]}\n"
    
    full_egoschema.append({
        'q_uid': data['q_uid'],
        'video': video,
        "QA": [{
            "i": "",
            "q": input_str.strip(),
            "a": "",
        }]
    })


def infer_full_egoschema(
        data_sample, system="", 
        question_prompt='', # add in the end of question
        answer_prompt=None, # add in the begining of answer
        return_prompt='',  # add in the begining of return message
        system_q=False, # whether add question in the system prompt for QFormer
        print_res=True,
        system_llm=False,
        num_segments=8,
    ):
    vid_path = os.path.join("your_data_path/egoschema/videos", data_sample['video'])
    print(vid_path)
    video, _ = load_video(vid_path, num_segments=num_segments, return_msg=True)
    T_, C, H, W = video.shape
    video = video.reshape(1, T_, C, H, W).to("cuda:0")
    
    video_list = []
    with torch.no_grad():
        if system_q:
            video_emb, _, _ = model.encode_img(video, system + data_sample['question'])
        else:
            video_emb, _, _ = model.encode_img(video, system)
    video_list.append(video_emb[0])

    chat = EasyDict({
        "system": system,
        "roles": ("[INST]", "[/INST]"),
        "messages": [],
        "sep": ""
    })

    chat.messages.append([chat.roles[0], f"<Video><VideoHere></Video> [/INST]"])
    
    if system_llm:
        prompt = system + data_sample['QA'][0]['q'] + question_prompt
    else:
        prompt = data_sample['QA'][0]['q'] + question_prompt
    
    ask(prompt, chat)

    llm_message = answer(
        conv=chat, model=model, do_sample=False, 
        img_list=video_list, max_new_tokens=100, 
        answer_prompt=answer_prompt, print_res=print_res
    )[0]
    # remove potential explanation
    llm_message = return_prompt + llm_message.strip().split('\n')[0]
    print(llm_message)
    return llm_message


#  position embedding
num_frame = 16
resolution = 224
new_pos_emb = get_sinusoid_encoding_table(n_position=(resolution//16)**2*num_frame, cur_frame=num_frame)
model.vision_encoder.encoder.pos_embed = new_pos_emb


ans_dict = {}

for idx, example in enumerate(tqdm(full_egoschema)):
    start = time.time()
    llm_message = infer_full_egoschema(
        example, 
        "Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons. Based on your observations, select the best option that accurately addresses the question.\n", 
        question_prompt="\nOnly give the best option.", 
        answer_prompt="Best option:(",
        return_prompt='(',
        system_q=False,
        print_res=False,
        system_llm=False,
        num_segments=16,
    )

    assert llm_message[0] == '(' and llm_message[2] == ')'
    ans = ord(llm_message[1]) - ord('A')
    assert ans in [0, 1, 2, 3, 4]
    ans_dict[example['q_uid']] = ans


with open("./demo/egoschema/your_prediction.json", "w") as f:
    json.dump(ans_dict, f)

# Then you can run https://github.com/egoschema/EgoSchema/blob/main/validate.py to get the score
# python3 validate.py --f ./your_prediction.json

In [None]:
import webvtt
import re

def clean_text(text):
    cleaned_text = re.sub(r'[^A-Za-z0-9\s]', '', text)
    return cleaned_text


def read_vtt_and_concatenate(file_path, tokenizer, max_len=4096):
    prev = ""
    subtitles = []
    for caption in webvtt.read(file_path):
        # Split the caption text into individual lines
        lines = caption.text.split('\n')
        for line in lines:
            # Clean the text and check for repetition
            line = clean_text(line)
            if prev != line and line:
                subtitles.append(line)
                prev = line

    # Join subtitles to check length
    full_text = ' '.join(subtitles)
    tokenized_ids = tokenizer(full_text, add_special_tokens=False).input_ids

    # If the tokenized length is within the limit, return the full text
    if len(tokenized_ids) <= max_len:
        return full_text

    # Otherwise, we need to trim the text to fit within the limit
    # We will keep the first half and the last half
    half_len = max_len // 2
    start_text = ' '.join(subtitles[:half_len])
    end_text = ' '.join(subtitles[-half_len:])
    
    # Re-tokenize to ensure the total length is within the limit
    start_tokenized_ids = tokenizer(start_text, add_special_tokens=False).input_ids
    end_tokenized_ids = tokenizer(end_text, add_special_tokens=False).input_ids

    # Adjust the lengths to fit within the max_len
    while len(start_tokenized_ids) + len(end_tokenized_ids) > max_len:
        if len(start_tokenized_ids) > len(end_tokenized_ids):
            start_tokenized_ids.pop()
        else:
            end_tokenized_ids.pop(0)
    
    # Combine the adjusted parts
    adjusted_text = tokenizer.decode(start_tokenized_ids) + ' ... ' + tokenizer.decode(end_tokenized_ids)
    
    return adjusted_text

    
class MME_dataset(Dataset):
    def __init__(self, data_prefix, anno_path, num_segments=16, resolution=224, hd_num=6, max_subtitle_len=4096):
        self.data_prefix = data_prefix
        with open(anno_path, 'r') as f:
            self.data_list = json.load(f)
        
        self.hd_num = hd_num
        self.num_segments = num_segments
        self.resolution = resolution
        self.max_subtitle_len = max_subtitle_len

        # transform
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
        self.transform = transforms.Compose([
            transforms.Lambda(lambda x: x.float().div(255.0)),
            transforms.Normalize(mean, std)
        ])
        self.hd_transform = HD_transform_no_padding
    
    def __str__(self):
        task_dict = {}
        total = 0
        for data in self.data_list:
            if data['duration_category'] not in ans_dict:
                task_dict[data['duration_category']] = {}
            for q in data['questions']:
                if q['task_type'] not in ans_dict[data['duration_category']]:
                    ans_dict[data['duration_category']][q['task_type']] = 0
                ans_dict[data['duration_category']][q['task_type']] += 1
                total += 1

        res = f"There are {len(self.data_list)} videos.\n"
        res += f"There are {total} QAs.\n"
        for k, v in task_dict.items():
            res += f"------{k}------\n"
            for kk, vv in task_dict.items():
                res += f"{kk}: {vv}\n"
                
        return res.rstrip()
        
    def __len__(self):
        return len(self.data_list)
    
    def get_index(self, bound, fps, max_frame, first_idx=0):
        if bound:
            start, end = bound[0], bound[1]
        else:
            start, end = -100000, 100000
        start_idx = max(first_idx, round(start * fps))
        end_idx = min(round(end * fps), max_frame)
        seg_size = float(end_idx - start_idx) / self.num_segments
        frame_indices = np.array([
            int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
            for idx in range(self.num_segments)
        ])
        return frame_indices

    def read_frame(self, video_path, bound=None):
        video_path = os.path.join(video_path, str(self.num_segments))
        
        if os.path.exists(video_path):
            frame_list = [p for p in os.listdir(video_path)]
        else:
            raise Exception
            
        images_group = list()
        
        for frame_name in frame_list:
            img = Image.open(os.path.join(video_path, frame_name))
            img = PILToTensor()(img).unsqueeze(0)
            img = self.hd_transform(img.float(), image_size=self.resolution, hd_num=self.hd_num)
            images_group.append(img)
        torch_imgs = self.transform(torch.vstack(images_group))
        return torch_imgs
    
    def read_video(self, video_path, bound=None):
        vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
        max_frame = len(vr) - 1
        fps = float(vr.get_avg_fps())
        frame_indices = self.get_index(bound, fps, max_frame, first_idx=0) 

        frames = vr.get_batch(frame_indices)
        frames = frames.permute(0, 3, 1, 2)
        frames = self.hd_transform(frames.float(), image_size=self.resolution, hd_num=self.hd_num)
        torch_imgs = self.transform(frames)
        return torch_imgs

    def qa_template(self, data):
        question = f"Question: {data['question']}\n"
        question += "Options:\n"
        answer = data['answer']
        answer = f"({answer}) {data['choices'][ord(answer) - ord('A')][3:]}"
        for idx, c in enumerate(data['choices']):
            cur_choice, cur_text = c[0], c[3:]
            question += f"({cur_choice}) {cur_text}\n"
        question = question.rstrip()
        return question, answer

    def __getitem__(self, idx):
        video_name = self.data_list[idx]['url'].split("watch?v=")[1]
        video_path = os.path.join(self.data_prefix, "frames", video_name)

        # We store the videos with only 16 or 32 frames for testing,
        # since directly reading the whold videos cost a lot of time.
        # You can also read the whole video via self.read_video(video_path)
        torch_imgs = self.read_frame(video_path)
        duration_category = self.data_list[idx]['duration_category']
        qa_list = []
        for qa in self.data_list[idx]['questions']:
            qa_list.append(self.qa_template(qa))

        subtitle = ""
        try:
            subtitle_path = os.path.join(self.data_prefix, "subtitle", video_name + ".vtt")
            if os.path.exists(subtitle_path):
                subtitle = read_vtt_and_concatenate(subtitle_path, model.mistral_tokenizer, self.max_subtitle_len)
        except Exception:
            subtitle = ""
            print(f"Error for {subtitle_path}")
            
        return {
            'subtitle': subtitle,
            'video': torch_imgs, 
            'qa_list': qa_list,
            'duration_category': duration_category
        }
    

def infer_mme(
        data_sample, system="", 
        question_prompt='', # add in the end of question
        answer_prompt=None, # add in the begining of answer
        return_prompt='',  # add in the begining of return message
        system_q=False, # whether add question in the system prompt for QFormer
        print_res=True,
        system_llm=False,
        add_subtitle=False,
    ):
    assert system_q == False, "do not support system_q now"
    video = data_sample["video"]
    T_, C, H, W = video.shape
    video = video.reshape(1, T_, C, H, W).to("cuda:0")
    
    video_list = []
    with torch.no_grad():
        if system_q:
            raise NotImplementedError
        else:
            video_emb, _, _ = model.encode_img(video, system)
    video_list.append(video_emb[0])

    pred_list = []
    gt_list = []
    for idx, qa in enumerate(data_sample['qa_list']):
        print(f"----------qa_{idx}---------", flush=True)
        chat = EasyDict({
            "system": system,
            "roles": ("[INST]", "[/INST]"),
            "messages": [],
            "sep": ""
        })
    
        if add_subtitle:
            if data_sample['subtitle'] != '':
                subtitle = f"This video's subtitles are listed below: {data_sample['subtitle']}"
                chat.messages.append([chat.roles[0], f"{subtitle}\n<Video><VideoHere></Video> [/INST]"])
            else:
                chat.messages.append([chat.roles[0], f"<Video><VideoHere></Video> [/INST]"])
        else:
            chat.messages.append([chat.roles[0], f"<Video><VideoHere></Video> [/INST]"])
    
        if system_llm:
            prompt = system + qa[0] + question_prompt
        else:
            prompt = qa[0] + question_prompt
        
        ask(prompt, chat)
    
        llm_message = answer(
            conv=chat, model=model, do_sample=False, 
            img_list=video_list, max_new_tokens=100, 
            answer_prompt=answer_prompt, print_res=print_res
        )[0]
        # remove potential explanation
        llm_message = return_prompt + llm_message.strip().split('\n')[0]
        print(f"Pred: {llm_message}", flush=True)
        print(f"GT: {qa[1]}", flush=True)
        pred_list.append(llm_message[1])
        gt_list.append(qa[1][1])
    return pred_list, gt_list

    
#  position embedding
num_frame = 16
resolution = 224
new_pos_emb = get_sinusoid_encoding_table(n_position=(resolution//16)**2*num_frame, cur_frame=num_frame)
model.vision_encoder.encoder.pos_embed = new_pos_emb

data_dir = "your_data_path/videomme"
anno_path =  "your_data_path/Video-MME.json"
dataset = MME_dataset(
    data_dir, 
    anno_path, 
    num_segments=num_frame, resolution=resolution
)

with open(anno_path, 'r') as f:
    res_json_data = json.load(f)

save_path = "./demo/videomme/your_prediction"

correct = 0
total = 0
res_list = []
acc_dict = {}

for idx, example in enumerate(tqdm(dataset)):
    duration_category = example['duration_category']
    if duration_category not in acc_dict:
        acc_dict[duration_category] = [0, 0] # correct, total
    qa_count = len(example['qa_list'])
    acc_dict[duration_category][1] += qa_count
    total += qa_count
    pred_list, gt_list = infer_mme(
        example, 
        "Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons. Based on your observations, select the best option that accurately addresses the question.\n",
        question_prompt="\nOnly give the best option.",
        answer_prompt="Best option:(",
        return_prompt='(',
        system_q=False,
        print_res=False,
        system_llm=True,
        # add_subtitle=True, # Comment this line to add subtitles, we use the whole subtitles by default.
    )
    res_list.append({
        'pred': pred_list,
        'gt': gt_list
    })
    qa_idx = 0
    for pred, gt in zip(pred_list, gt_list):
        if pred == gt:
            acc_dict[duration_category][0] += 1
            correct += 1
        res_json_data[idx]['questions'][qa_idx]['response'] = pred
        qa_idx += 1
    print(f"Part  Acc: {acc_dict[duration_category][0] / acc_dict[duration_category][1] * 100 :.2f}%")
    print(f"Total Acc: {correct / total * 100 :.2f}%")
    print('-' * 50, duration_category, '-' * 50)

with open(f"{save_path}.json", "w") as f:
    json.dump({
        "acc_dict": acc_dict,
        "res_list": res_list
    }, f)

with open(f"{save_path}_full.json", "w") as f:
    json.dump(res_json_data, f)

# Then you can run https://github.com/BradyFU/Video-MME/blob/main/evaluation/eval_your_results.py to get the score
# python3 eval.py --results_file your_prediction_full.json --video_duration_type short,medium,long