In [1]:
import av
import numpy as np

from transformers import Trainer, TrainingArguments, Seq2SeqTrainingArguments, DataCollatorForLanguageModeling
from transformers import AutoProcessor, BitsAndBytesConfig, LlavaNextVideoForConditionalGeneration, DataCollatorWithPadding 
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model

import torch
from huggingface_hub import  hf_hub_download
import av
import torch
import numpy as np
from transformers import LlavaNextVideoForConditionalGeneration, LlavaNextVideoProcessor, BitsAndBytesConfig
from huggingface_hub import hf_hub_download
import numpy as np
import evaluate
from activity_dataset import get_dataset_splits



MODEL_ID = "llava-hf/LLaVA-NeXT-Video-7B-hf"
MAX_LENGTH = 4096
USE_LORA = False
USE_QLORA = True

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def read_video_pyav(container, indices):
    '''
    Decode the video with PyAV decoder.
    Args:
        container (`av.container.input.InputContainer`): PyAV container.
        indices (`List[int]`): List of frame indices to decode.
    Returns:
        result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
    '''
    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame)
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)

# Load the model in half-precision
processor = LlavaNextVideoProcessor.from_pretrained(MODEL_ID)
model = LlavaNextVideoForConditionalGeneration.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
    device_map='auto',
)
processor.tokenizer.padding_side = "left"
processor.image_processor.do_rescale = False

#model.config.video_token_index = model.config.image_token_index = 32001

Loading checkpoint shards: 100%|██████████| 3/3 [00:03<00:00,  1.22s/it]


In [3]:
train_dataset, test_dataset = get_dataset_splits()

In [4]:
if USE_QLORA or USE_LORA:
    if USE_QLORA:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
    model = LlavaNextVideoForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        quantization_config=bnb_config,
        device_map="auto",
    )

Loading checkpoint shards: 100%|██████████| 3/3 [00:03<00:00,  1.20s/it]


In [5]:
def find_all_linear_names(model):
    cls = torch.nn.Linear
    lora_module_names = set()
    multimodal_keywords = ['multi_modal_projector', 'vision_model']
    for name, module in model.named_modules():
        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
            continue
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if 'lm_head' in lora_module_names: # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)


lora_config = LoraConfig(
    r=8,
    lora_alpha=8,
    lora_dropout=0.1,
    target_modules=find_all_linear_names(model),
    init_lora_weights="gaussian",
)

model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)

In [6]:
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

In [7]:
OUTPUT_DIR = "output_llava"
BATCH_SIZE = 1
REPO_ID = "cams01/LLaVa-robot-activity-recognition"


args = TrainingArguments(
    # args related to training
    output_dir = OUTPUT_DIR,
    eval_strategy = 'steps',
    eval_steps=20,
    per_device_train_batch_size = BATCH_SIZE,
    per_device_eval_batch_size = BATCH_SIZE,
    gradient_accumulation_steps = 8,
    learning_rate = 2e-05,
    max_steps = 10, # adjust this depending on your dataset size
    lr_scheduler_type = 'cosine',
    warmup_ratio = 0.1,

    # args related to eval/save
    logging_steps = 20,
    save_strategy = 'steps',
    save_steps=20,
    save_total_limit = 1,
    fp16 = True, # we have the model train and eval with fp16 precision
    fp16_full_eval = True,
    optim = 'adamw_bnb_8bit', # adam in lower-bits to save memory, consider changing to 'adamw_torch' if model is not converging
    # report_to = "wandb", # install wand to use this
    # hub_model_id = REPO_ID,
    # push_to_hub = True, # wel'll push the model to hub after each epoch

    # model that was wrapped for QLORA training with peft will not have arguments listed in its signature
    # so we need to pass lable names explicitly to calculate val loss
    label_names=["labels"],
    # dataloader_num_workers=4, # let's get more workers since iterating on video datasets might be slower in general
)

In [25]:
class LlavaNextVideoDataCollatorWithPadding:
    def __init__(self, processor):
        self.processor = processor

    def __call__(self, features):
        padded_inputs = self.processor.tokenizer.pad(
            {
                "input_ids": [feat['input_ids'][0] for feat in features], # each element is one batch only so we slice [0]
                "attention_mask": [feat['attention_mask'][0] for feat in features],
            },
            padding=True,
            return_tensors="pt",
        )

        labels = padded_inputs["input_ids"].clone()
        labels[labels == self.processor.tokenizer.pad_token_id] = -100
        padded_inputs["labels"] = labels
        padded_inputs["pixel_values_videos"] = torch.cat([feat['pixel_values_videos'] for feat in features], dim=0)
        
        print({k:v.shape for k, v in padded_inputs.items()})

        return padded_inputs

In [26]:

trainer = Trainer(
    model = model,
    processing_class = processor,
    data_collator = LlavaNextVideoDataCollatorWithPadding(processor=processor),
    train_dataset = train_dataset,
    eval_dataset = test_dataset,
    args=args,
)

In [27]:
trainer.train()

torch.Size([221, 3, 336, 336])
{'input_ids': torch.Size([1, 4096]), 'attention_mask': torch.Size([1, 4096]), 'labels': torch.Size([1, 4096]), 'pixel_values_videos': torch.Size([1, 221, 3, 336, 336])}
torch.Size([267, 3, 336, 336])


KeyboardInterrupt: 

In [12]:
from pathlib import Path
import av
import torch
import numpy as np
from huggingface_hub import hf_hub_download
import torchvision
from transformers import LlavaNextVideoProcessor, LlavaNextVideoForConditionalGeneration
from activity_dataset import transformation

model_id = "llava-hf/LLaVA-NeXT-Video-7B-hf"

model = LlavaNextVideoForConditionalGeneration.from_pretrained(
    model_id, 
    torch_dtype=torch.float16, 
    low_cpu_mem_usage=True, 
).to(0)

processor = LlavaNextVideoProcessor.from_pretrained(model_id)

def read_video_pyav(container, indices):
    '''
    Decode the video with PyAV decoder.
    Args:
        container (`av.container.input.InputContainer`): PyAV container.
        indices (`List[int]`): List of frame indices to decode.
    Returns:
        result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
    '''
    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame)
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])


# define a chat history and use `apply_chat_template` to get correctly formatted prompt
# Each value in "content" has to be a list of dicts with types ("text", "image", "video") 
conversation = [
    {

        "role": "user",
        "content": [
            {"type": "text", "text": "Describe the video and the activities of the subjects"},
            {"type": "video"},
            ],
    },
]

prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

video_path = Path("atlas_dione_objectdetection\ATLAS_Dione_ObjectDetection\ATLAS_Dione_ObjectDetection_Study_ActionClips\ATLAS_Dione_ObjectDetection_Study_ActionClips\set07\set07V001.mkv")
clip,_,_  = torchvision.io.read_video(video_path, pts_unit = "sec", output_format="TCHW")
clip = torch.stack([transformation(frame) for frame in clip])
inputs_video = processor(text=prompt, videos=clip[:2], padding=True, return_tensors="pt").to(model.device)

output = model.generate(**inputs_video, max_new_tokens=100, do_sample=False)
print(processor.decode(output[0][2:], skip_special_tokens=True))


Loading checkpoint shards: 100%|██████████| 3/3 [00:06<00:00,  2.11s/it]


ER: 
Describe the video and the activities of the subjects ASSISTANT: The video depicts a group of people engaging in various activities in a dark room. The focus is on a man who is seen walking towards the camera, and then he turns around and walks away. Another man is seen walking towards the camera and turns around, but he is not as clear as the first man. The video seems to be shot in a dark room with limited lighting, and the subjects are not clearly visible. The activities of the subjects are not clear due to the
