In [2]:
import os
# Restrict this process to GPU 0. It is set before torch / vtimellm are imported below
# so the setting is in place when CUDA is first initialized.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Repository root, taken as the parent of the notebook's working directory
# (the printed output shows the notebook running from .../vtimellm/docs).
root_dir = os.path.join(os.getcwd(), "..")
print(root_dir)
import sys
# Make the `vtimellm` package importable; must happen before the vtimellm imports below.
sys.path.append(root_dir)

from vtimellm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from vtimellm.conversation import conv_templates, SeparatorStyle
from vtimellm.model.builder import load_pretrained_model, load_lora
from vtimellm.train.dataset import preprocess_glm
from vtimellm.utils import disable_torch_init
from vtimellm.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria, VideoExtractor
from PIL import Image
import requests
from io import BytesIO
from transformers import TextStreamer
from easydict import EasyDict as edict
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    from PIL import Image
    BICUBIC = Image.BICUBIC
from torchvision.transforms import Compose, Resize, CenterCrop, Normalize
import numpy as np
import clip
import torch

/DATA/DATANAS2/bhuang/link/gitlab/vtimellm/docs/..


In [3]:
# Backbone selection: 'chatglm3-6b' (bundled under checkpoints/) or 'vicuna-v1-5-7b'.
model_version = 'chatglm3-6b'

args = edict()
if model_version == 'chatglm3-6b':
    # ChatGLM3 base weights live inside the repository's checkpoints directory.
    args.model_base = os.path.join(root_dir, 'checkpoints/' + model_version)
else:
    # TODO(review): machine-specific absolute path for the Vicuna base model.
    args.model_base = "/DATA/DATANAS2/bhuang/data/vicuna-7b-v1.5"

# CLIP visual encoder checkpoint.
args.clip_path = os.path.join(root_dir, "checkpoints/clip/ViT-L-14.pt")
# Stage-1 projector weights plus the stage-2 / stage-3 checkpoint directories
# that are handed to load_pretrained_model.
args.pretrain_mm_mlp_adapter = os.path.join(root_dir, f"checkpoints/vtimellm-{model_version}-stage1/mm_projector.bin")
args.stage2 = os.path.join(root_dir, f"checkpoints/vtimellm-{model_version}-stage2")
args.stage3 = os.path.join(root_dir, f"checkpoints/vtimellm-{model_version}-stage3")
# Sampling temperature for generation in the chat loop.
args.temperature = 0.05

In [4]:
disable_torch_init()
# Load the base LLM and merge the stage-2 and stage-3 weights on top of it.
tokenizer, model, context_len = load_pretrained_model(args, args.stage2, args.stage3)
# Move the merged model to the GPU and cast its floating-point weights to fp16 in a single call.
model = model.to(device='cuda', dtype=torch.float16)

You are using a model of type chatglm to instantiate a model of type VTimeLLM_ChatGLM. This is not supported for all configurations of models and can yield errors.


Loading VTimeLLM from base model...


Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

load mlp: /DATA/DATANAS2/bhuang/link/gitlab/vtimellm/docs/../checkpoints/vtimellm-chatglm3-6b-stage1/mm_projector.bin
Loading stage2 weights...
Loading LoRA weights...
Merging stage2 weights...
Loading stage3 weights...
Loading LoRA weights...
Merging stage3 weights...


In [13]:
# CLIP image encoder from the local checkpoint; the preprocess pipeline clip.load
# returns is not used (a tensor-based pipeline is built below instead).
clip_model, _ = clip.load(args.clip_path)
clip_model = clip_model.eval().cuda()

# Samples a fixed number of frames (N=100) per video.
video_loader = VideoExtractor(N=100)

# Tensor pipeline applied to frames already scaled to [0, 1]:
# bicubic resize of the short side to 224, 224x224 center crop, then CLIP
# mean/std normalization.
transform = Compose(
    [
        Resize(224, interpolation=BICUBIC),
        CenterCrop(224),
        Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)


In [None]:
# Video to describe; the id returned by the extractor is not needed.
args.video_path = '/DATA/DATANAS2/bhuang/link/1.mp4'
_, images = video_loader.extract({'id': None, 'video': args.video_path})
# Frames are <N, 3, H, W> (per the original debug note). Dividing by 255 assumes
# 0-255 pixel values. After the CLIP transform, cast to fp16 for the encoder.
images = transform(images / 255.0).to(torch.float16)
with torch.no_grad():
    features = clip_model.encode_image(images.to('cuda'))

In [14]:
def inference(model, tokenizer, context_len, image, args):
    """Run an interactive multi-turn chat about a video with the VTimeLLM ChatGLM model.

    Reads user turns from stdin until an empty line (or EOF) is entered. Each
    assistant reply is streamed to stdout. All turns are appended to a running
    dialogue so later turns are conditioned on earlier ones.

    Args:
        model: VTimeLLM model exposing ``generate(input_ids, images=..., ...)``.
        tokenizer: ChatGLM tokenizer providing ``get_command`` and ``eos_token_id``.
        context_len: Unused; kept for signature compatibility.
        image: Visual features for the video (in this notebook, the CLIP
            ``encode_image`` output). A leading batch dimension is added before
            it is passed to ``generate``.
        args: Config object; only ``args.temperature`` is read.
    """
    source = []
    first = True
    # Generation stops on either the next user-turn marker or EOS.
    stop_ids = [tokenizer.get_command("<|user|>"), tokenizer.eos_token_id]
    while True:
        try:
            inp = input("USER: ")
        except EOFError:
            inp = ""
        if not inp:
            print("exit...")
            break

        print("ASSISTANT:", end="")

        if first:
            # The image placeholder is inserted only into the first human turn.
            inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
            first = False

        source.append({
            'from': "human",
            'value': inp
        })
        input_ids = preprocess_glm([source], tokenizer)['input_ids'].cuda()
        # Overwrite the final prompt token with the assistant marker so the model
        # continues with an assistant reply. NOTE(review): assumes preprocess_glm
        # ends the sequence with one token that should be replaced here — confirm.
        input_ids[0][-1] = tokenizer.get_command("<|assistant|>")
        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image[None,].cuda(),
                do_sample=True,
                temperature=args.temperature,
                max_new_tokens=1024,
                streamer=streamer,
                use_cache=True,
                eos_token_id=stop_ids,
            )

        # Keep only the newly generated tokens. Drop the trailing token only when it
        # is a stop token: if generation ended at max_new_tokens, the last token is
        # real reply text and must stay in the dialogue history.
        reply_ids = output_ids[0, input_ids.shape[1]:]
        if len(reply_ids) > 0 and reply_ids[-1].item() in stop_ids:
            reply_ids = reply_ids[:-1]
        outputs = tokenizer.decode(reply_ids).strip()
        source.append({
            'from': "gpt",
            'value': outputs
        })

In [19]:
# Start the interactive chat over the CLIP features of the extracted video frames.
inference(model, tokenizer, context_len, image=features, args=args)

ASSISTANT:视频中，一名男子在黑暗的房间里，手里拿着一个装满东西的盒子。他打开盒子，里面装满了各种物品。然后，该男子爬上一座高高的建筑物，并从窗户跳入水中。<|user|>
exit...
