In [1]:
import os
# Repository root, taken as the parent of the current working directory.
# NOTE(review): this assumes the notebook is launched from a direct subdirectory
# of the repo (e.g. docs/); it also leaves the path un-normalized ("<cwd>/..").
root_dir = os.path.join(os.getcwd(), "..")
import sys
# Make the local `vtimellm` package importable. Because the path is appended
# (not inserted at index 0), an installed `vtimellm` would take precedence.
sys.path.append(root_dir)
from vtimellm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from vtimellm.conversation import conv_templates, SeparatorStyle
from vtimellm.model.builder import load_pretrained_model, load_lora
from vtimellm.utils import disable_torch_init
from vtimellm.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria, VideoExtractor
from PIL import Image
import requests
from io import BytesIO
from transformers import TextStreamer
from easydict import EasyDict as edict
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    from PIL import Image
    BICUBIC = Image.BICUBIC
from torchvision.transforms import Compose, Resize, CenterCrop, Normalize
import numpy as np
import clip
import torch

In [2]:
# Demo configuration. Checkpoint locations default to the layout under `root_dir`;
# the two values users most often need to change can be overridden with
# environment variables, so this cell does not have to be edited to run the demo.
args = edict()
# Base LLM weights. The default is a placeholder: point it at a real Vicuna-7B v1.5
# checkpoint here or via the VTIMELLM_MODEL_BASE environment variable.
args.model_base = os.environ.get("VTIMELLM_MODEL_BASE", "/path/to/vicuna-7b-v1.5")
# CLIP ViT-L/14 weights used to encode video frames.
args.clip_path = os.path.join(root_dir, "checkpoints/clip/ViT-L-14.pt")
# Stage-1 vision->LLM projector weights.
args.pretrain_mm_mlp_adapter = os.path.join(root_dir, "checkpoints/vtimellm-vicuna-v1-5-7b-stage1/mm_projector.bin")
# Stage-2 / stage-3 checkpoint directories passed to load_pretrained_model.
args.stage2 = os.path.join(root_dir, "checkpoints/vtimellm-vicuna-v1-5-7b-stage2")
args.stage3 = os.path.join(root_dir, "checkpoints/vtimellm-vicuna-v1-5-7b-stage3")
# Video to chat about (override with VTIMELLM_VIDEO_PATH).
args.video_path = os.environ.get("VTIMELLM_VIDEO_PATH", os.path.join(root_dir, "images/demo.mp4"))
# Sampling temperature for generation; a small value keeps sampling close to greedy.
args.temperature = 0.05

In [3]:
def inference(model, tokenizer, context_len, image, args):
    """Run an interactive, multi-turn chat about a single video.

    Reads user turns with ``input()`` until an empty line (or EOF) is entered,
    streams each assistant reply to stdout through ``TextStreamer``, and keeps
    the running conversation in a copy of the ``'v1'`` template so that later
    turns see earlier answers.

    Args:
        model: VTimeLLM model from ``load_pretrained_model``; must accept
            ``generate(..., images=...)``.
        tokenizer: Tokenizer paired with ``model``.
        context_len: Context length reported by ``load_pretrained_model``.
            Currently unused; kept for interface compatibility.
        image: Visual features for the whole video (in this notebook, the CLIP
            ``encode_image`` output). A leading batch dimension is added before
            it is passed to ``model.generate``.
        args: Config object. ``args.temperature`` is required;
            ``args.max_new_tokens`` is optional and defaults to 1024.
    """
    conv = conv_templates['v1'].copy()
    user_role, assistant_role = conv.roles[0], conv.roles[1]
    # The video features are identical for every turn, so batch/move them once.
    images = image[None,].cuda()
    max_new_tokens = getattr(args, 'max_new_tokens', 1024)
    first = True
    while True:
        try:
            inp = input(f"{user_role}: ")
        except EOFError:
            inp = ""
        if not inp:
            print("exit...")
            break

        print(f"{assistant_role}: ", end="")

        if first:
            # Only the first user message carries the image placeholder token;
            # tokenizer_image_token presumably maps it to IMAGE_TOKEN_INDEX.
            inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
            first = False
        conv.append_message(user_role, inp)
        conv.append_message(assistant_role, None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 # plain:sep(###) v1:sep2(None)
        stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)
        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=images,
                do_sample=True,
                temperature=args.temperature,
                max_new_tokens=max_new_tokens,
                streamer=streamer,
                use_cache=True,
                stopping_criteria=[stopping_criteria],
            )

        # Assumes generate() returns prompt + new tokens, so slice off the prompt.
        outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
        # decode() is called without skip_special_tokens, so the reply can end with
        # the stop string. Strip it before storing the turn so it is not repeated
        # when the template adds its own separator while building the next prompt
        # (NOTE(review): confirm against Conversation.get_prompt).
        if stop_str and outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)].strip()
        conv.messages[-1][-1] = outputs

In [4]:
# Build the VTimeLLM model: base LLM, stage-1 projector, then the stage-2 and
# stage-3 weights merged on top (see the loader log below).
# disable_torch_init() is called first; presumably it skips redundant weight
# initialisation during loading (see vtimellm.utils) — verify.
disable_torch_init()
tokenizer, model, context_len = load_pretrained_model(args, args.stage2, args.stage3)
# Move to the GPU and cast to fp16 in one .to() call. Module.to converts each tensor
# straight to the target device and dtype, so the GPU never has to hold a
# full-precision copy of the weights. The previous .cuda() followed by
# .to(float16) needed that copy whenever the loader returned non-fp16 weights.
model = model.to(device="cuda", dtype=torch.float16)

You are using a model of type llama to instantiate a model of type VTimeLLM. This is not supported for all configurations of models and can yield errors.


Loading VTimeLLM from base model...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



load mlp: /DATA/DATANAS2/bhuang/link/gitlab/vtimellm/docs/../checkpoints/vtimellm-vicuna-v1-5-7b-stage1/mm_projector.bin
Loading stage2 weights...
Loading LoRA weights...
Merging stage2 weights...
Loading stage3 weights...
Loading LoRA weights...
Merging stage3 weights...


In [5]:
# Encode the demo video into per-frame CLIP features; `features` is the visual
# input passed to `inference` in the chat cell.

# Frozen CLIP image encoder on the GPU (eval() and cuda() both return the module).
clip_model = clip.load(args.clip_path)[0].eval().cuda()

# Sample 100 frames from the video; the first value returned by extract() is unused.
video_loader = VideoExtractor(N=100)
_, images = video_loader.extract(dict(id=None, video=args.video_path))

# Resize/crop to 224x224 and normalize with the mean/std used by CLIP's own preprocessing.
transform = Compose([
    Resize(224, interpolation=BICUBIC),
    CenterCrop(224),
    Normalize(
        mean=(0.48145466, 0.4578275, 0.40821073),
        std=(0.26862954, 0.26130258, 0.27577711),
    ),
])

# Frames are assumed to be 0-255 values shaped <N, 3, H, W> — TODO confirm.
# Scale to [0, 1], transform, and cast to half precision (presumably to match the
# fp16 CLIP weights on GPU).
images = transform(images / 255.0).to(torch.float16)
with torch.no_grad():
    features = clip_model.encode_image(images.cuda())



In [8]:
# Start the interactive chat: type a question at each prompt; an empty line exits.
inference(model, tokenizer, context_len, image=features, args=args)

USER:  Explain why this video is funny.


ASSISTANT: The video is funny because the bear is dancing to the music and moving its arms and legs in a funny way. The bear's movements are exaggerated and comical, making it difficult for the person to keep up with the beat. The bear's facial expressions and body language add to the humor of the video.


USER:  Is it a real bear?


ASSISTANT: No, it is not a real bear. It is a costume worn by a person.


USER:  


exit...
