In [6]:
from dataset.prepare import DatasetPreparer
from transformers import VideoLlavaProcessor, VideoLlavaImageProcessor
from dataset.collate import DataCollatorWithPadding
from torch import tensor
import numpy as np
import torch
from datasets import load_from_disk
from matplotlib import pyplot as plt
from matplotlib import animation
from IPython.display import HTML
from torch.utils.data import DataLoader

In [13]:
# Pretrained Video-LLaVA checkpoint whose processor is used throughout the notebook.
MODEL_ID = "LanguageBind/Video-LLaVA-7B-hf"

# use_fast=False requests the non-fast processor/tokenizer implementation.
processor = VideoLlavaProcessor.from_pretrained(MODEL_ID, use_fast=False)

# Overrides applied on the loaded processor.
# NOTE(review): presumably needed so the processor can work out how many video
# placeholder tokens to emit -- confirm against the installed transformers version.
processor.patch_size = 14
processor.vision_feature_select_strategy = "default"

# Flag the "Asking-to-pad-a-fast-tokenizer" notice as already handled.
# NOTE(review): presumably this silences the padding advisory during collation -- confirm.
processor.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True

In [3]:
# Dataset preparer bound to the processor above, configured for 14 frames per
# video and rooted at test/prepare1 with a "processed" output directory.
dp = DatasetPreparer(
    base_dir='test/prepare1',
    processed_dir='processed',
    processor=processor,
    num_frames=14,
)

In [4]:
# Reuse the preprocessed datasets saved on disk when they exist; otherwise build
# them with the DatasetPreparer. Only a missing cache directory triggers the
# rebuild, so interrupts (Ctrl-C) and unrelated errors surface instead of being
# silently swallowed and followed by an expensive re-preparation.

# Action-ordering task, "robust" variant.
try:
    ao = load_from_disk('test/prepare1/processed/action_ordering_v2/robust/14_frames')
except FileNotFoundError:
    ao = dp.prepare_dataset('action_ordering_v2', use_robust=True)

# Moment-retrieval task, "frame" variant, prepared with mr_max_actions=1.
try:
    mr = load_from_disk('test/prepare1/processed/moment_retrieval_v2/frame/14_frames')
except FileNotFoundError:
    mr = dp.prepare_dataset('moment_retrieval_v2', use_frame=True, mr_max_actions=1)

In [27]:
# Both loaders share the same settings: batches of 4, shuffled, one worker process.
# Each loader gets its own padding collator built around the processor.
# NOTE(review): the eval loader shuffles too, so the inspected test batch differs
# between runs unless a seed is set elsewhere.
loader_options = dict(batch_size=4, shuffle=True, num_workers=1)

train_dataloader = DataLoader(
    mr['train'],
    collate_fn=DataCollatorWithPadding(processor),
    **loader_options,
)
eval_dataloader = DataLoader(
    mr['test'],
    collate_fn=DataCollatorWithPadding(processor),
    **loader_options,
)

In [28]:
# Pull a single collated batch from the training loader for inspection.
# The iterator is a temporary, so it is not kept around after this batch.
train = next(iter(train_dataloader))

In [29]:
# Pull a single collated batch from the evaluation loader for inspection.
# The iterator is a temporary, so it is not kept around after this batch.
test = next(iter(eval_dataloader))

In [31]:
# Summarise the evaluation batch: shapes of the tensor fields, then the raw answers.
(
    *(test[field].shape for field in ('input_ids', 'attention_mask', 'pixel_values_videos')),
    test['answer'],
)

(torch.Size([4, 3841]),
 torch.Size([4, 3841]),
 torch.Size([4, 14, 3, 224, 224]),
 [['2', '14'], ['2', '6'], ['2', '8'], ['2', '10']])

In [33]:
# Summarise the training batch: shapes of the model inputs and of the labels.
tuple(
    train[field].shape
    for field in ('input_ids', 'attention_mask', 'pixel_values_videos', 'labels')
)

(torch.Size([4, 3845]),
 torch.Size([4, 3845]),
 torch.Size([4, 14, 3, 224, 224]),
 torch.Size([4, 3845]))

In [None]:
# Decode the evaluation batch's token ids back into text, dropping special tokens.
# clean_up_tokenization_spaces must be passed as a keyword: a bare positional name
# after a keyword argument is a SyntaxError. False leaves the decoded text
# untouched; switch to True to tidy spacing around punctuation.
processor.batch_decode(
    test['input_ids'],
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
)