In [1]:
import os
# Pin GPU selection for this kernel: enumerate devices in PCI bus order and
# expose only the one at index 3. These are set before any CUDA-using import.
os.environ.update({
    'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',
    'CUDA_VISIBLE_DEVICES': '3',
})

In [3]:
import random
from torch.utils.data import Dataset, DataLoader

from video_llama.processors import AlproVideoTrainProcessor
from video_llama.processors.video_processor import load_video
from video_llama.models.ImageBind.data import load_and_transform_audio_data

## Data Preprocessing

In [4]:
class VideoDataset(Dataset):
    """Map-style dataset that yields video frames, caption text and audio per annotation.

    Every annotation entry is a dict with a "video" key (a path, joined onto
    ``root`` when ``root`` is non-empty) and a "text" key. When an entry fails
    to load, a randomly chosen entry is tried instead, up to a bounded number
    of attempts.
    """

    def __init__(self, annotation, root="", num_frames=8, resize_size=224):
        """Store the annotation list and build the frame transform.

        Args:
            annotation: indexable sequence of dicts with "video" and "text" keys.
            root: directory prefixed to each "video" entry. An empty string
                means the entries are used exactly as given.
            num_frames: number of frames requested from ``load_video``; also
                passed as ``clips_per_video`` to the audio loader.
            resize_size: height and width requested from ``load_video`` and the
                ``image_size`` given to ``AlproVideoTrainProcessor``.
        """
        self.annotation = annotation
        self.root = root
        self.num_frames = num_frames
        self.resize_size = resize_size
        self.transform = AlproVideoTrainProcessor(
            image_size=resize_size,
            n_frms=num_frames,
        ).transform

    def __len__(self):
        """Return the number of annotation entries."""
        return len(self.annotation)

    def _video_path(self, sample):
        """Return the path to load for ``sample``, prefixing ``root`` only when it is set."""
        # Joining an empty root would produce "/<video>", i.e. silently turn a
        # relative path into one anchored at the filesystem root.
        if not self.root:
            return sample["video"]
        return "/".join([self.root, sample["video"]])

    def __getitem__(self, index, num_retries=10, device="cpu"):
        """Load one sample, falling back to random entries when loading fails.

        Args:
            index: position in ``annotation`` to try first.
            num_retries: maximum number of load attempts (including the first).
            device: forwarded to ``load_and_transform_audio_data``.

        Returns:
            dict with keys "image" (transformed video), "text" (caption) and
            "audio" (output of ``load_and_transform_audio_data``).

        Raises:
            RuntimeError: if no attempt succeeded within ``num_retries``.
            IndexError: propagated from ``annotation`` for an invalid ``index``
                (the lookup is intentionally outside the retry guard).
        """
        for _ in range(num_retries):
            sample = self.annotation[index]
            video_path = self._video_path(sample)
            try:
                video = load_video(
                    video_path=video_path,
                    n_frms=self.num_frames,
                    height=self.resize_size,
                    width=self.resize_size,
                    sampling="uniform",
                    return_msg=False,
                )
                # Built fresh on every attempt so a partially loaded sample
                # from a failed attempt can never leak into the result.
                result = {
                    "image": self.transform(video),
                    "text": sample["text"],
                    "audio": load_and_transform_audio_data(
                        [video_path],
                        device=device,
                        clips_per_video=self.num_frames,
                    ),
                }
            except Exception as e:
                # Best-effort by design: any load/decoding failure is reported
                # and a different (random) entry is tried on the next attempt.
                print(
                    f"Failed to load sample: {e}.",
                    "Will randomly sample an example as a replacement."
                )
                index = random.randint(0, len(self) - 1)
                continue
            return result
        raise RuntimeError(f"Failed to fetch sample after {num_retries} retries.")

In [6]:
# Directory that the relative "video" paths below are joined onto.
root = "/code/Video-LLaMA"

# Two (video, caption) pairs expanded into the dict records VideoDataset indexes.
annotation = [
    {"video": video, "text": text}
    for video, text in (
        ("examples/birthday.mp4", "birthday party"),
        ("examples/boat.mp4", "floating boat"),
    )
]

# One sample per batch, loaded through a single worker process.
dataset = VideoDataset(annotation, root)
loader = DataLoader(dataset, batch_size=1, num_workers=1)

## Test: run one batch through the backbone

In [7]:
import torch
from omegaconf import OmegaConf
from backbone import VideoLLAMABackbone

In [8]:
# Target device for the model and for the inputs fed to it.
device = torch.device("cuda")

# Read the model configuration from config.yaml in the working directory.
config = OmegaConf.load("config.yaml")

# Build the backbone from the config first, then move it onto the target device.
model = VideoLLAMABackbone.from_config(config)
model = model.to(device)

Load first Checkpoint: /code/Video-LLaMA/ckpt_12b/VL_LLaMA_2_13B_Finetuned.pth
Load second Checkpoint: /code/Video-LLaMA/ckpt_12b/AL_LLaMA_2_13B_Finetuned.pth


In [10]:
# Smoke test: fetch the first batch from the loader. The iterator is left
# unnamed so it is not kept alive after this statement.
sample = next(iter(loader))
# Only the "image" entry is used: it is moved to `device` and passed through the
# model in a single forward call.
output = model(sample["image"].to(device))
# Report the shape of the model output for a batch of one sample.
print(output.shape)

Failed to load sample: Source stream index out of range. Will randomly sample an example as a replacement.
Failed to load sample: Source stream index out of range. Will randomly sample an example as a replacement.
Failed to load sample: Source stream index out of range. Will randomly sample an example as a replacement.
torch.Size([1, 32, 5120])
