In [None]:
from project.dataset.collate import DataCollatorWithPadding
from project.dataset.prepare import MomentRetrievalDataset
from project.trainer.lightning import VideoLlavaModelPLModule
from project.trainer.peft import find_all_linear_names

from transformers import (
    VideoLlavaProcessor,
    BitsAndBytesConfig,
    VideoLlavaForConditionalGeneration,
    LlamaForCausalLM
)
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from datasets import load_from_disk
from torch.utils.data import DataLoader
import torch
from dataclasses import dataclass
from lightning import Trainer
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
import argparse
import logging
import os
import sys
from datetime import datetime
from lightning.pytorch.strategies import DeepSpeedStrategy

In [None]:
@dataclass
class Config:
    """Hyper-parameters and paths for the Video-LLaVA QLoRA fine-tuning run.

    Every field has a default, so ``Config()`` yields the baseline setup and
    single values can be overridden by keyword, e.g. ``Config(batch_size=4)``.
    """

    # LoRA rank and scaling factor, passed to ``LoraConfig`` as ``r`` / ``lora_alpha``.
    lora_r: int = 8
    lora_alpha: int = 16
    # Batch size used by both the training and validation dataloaders.
    batch_size: int = 2
    # Passed to the Trainer as ``max_epochs``.
    max_epoch: int = 2
    # Passed to the Trainer as ``val_check_interval``.
    val_check_interval: float = 0.25
    # Handed to the Lightning module as ``config["lr"]``.
    learning_rate: float = 2e-5
    # "<base_dir>/<processed_dir>"; split on "/" when building the dataset helper.
    dataset_dir: str = "datasets/processed"
    # Frames per video; also part of the on-disk dataset cache path.
    num_frames: int = 14
    num_worker: int = 2
    hub_repo: str = "jwnt4/finetune-videollava-qlora"
    accumulate_grad_batches: int = 2
    # Rounded down to a multiple of ``batch_size`` before it reaches the Trainer.
    limit_val_batches: float = 24


# Instantiate the dataclass: binding the bare class only works by accident
# (class-level defaults) and would make any later override global.
args = Config()

In [None]:
# Trade some float32 matmul precision for speed (see the PyTorch docs for 'medium').
torch.set_float32_matmul_precision('medium')

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Re-running this cell must not stack handlers on the same logger, otherwise
# every message is emitted once per previous execution.
for stale_handler in list(logger.handlers):
    logger.removeHandler(stale_handler)
    stale_handler.close()

# ``exist_ok`` makes the directory creation idempotent and race-free.
os.makedirs("logs", exist_ok=True)

# One log file per cell execution, named after the current local timestamp.
log_file = f"logs/{str(datetime.now()).replace(' ', '_')}.log"

log_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S")

stream_handler = logging.StreamHandler(sys.stdout)
file_handler = logging.FileHandler(log_file)

# Both sinks share the same format; previously the file handler was left
# unformatted, so the log file lacked timestamps and levels.
for handler in (stream_handler, file_handler):
    handler.setFormatter(log_formatter)
    logger.addHandler(handler)

In [None]:
# Load the Video-LLaVA processor with the slow (non-fast) tokenizer.
processor = VideoLlavaProcessor.from_pretrained(
    "LanguageBind/Video-LLaVA-7B-hf",
    use_fast=False,
)

# Vision-side settings are set explicitly on the processor instance.
processor.vision_feature_select_strategy = "default"
processor.patch_size = 14

# Mark this tokenizer warning as already issued; presumably this keeps it from
# being repeated for every padded batch (confirm against the transformers version in use).
processor.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True

In [None]:
# ``dataset_dir`` has the form "<base_dir>/<processed_dir>"; only the first two
# "/"-separated components are used.
base_dir, processed_dir = args.dataset_dir.split("/")[:2]

# Dataset helper; the worker count comes from the config instead of a literal
# so that changing ``num_worker`` there actually takes effect.
dp = MomentRetrievalDataset(
    base_dir=base_dir,
    processed_dir=processed_dir,
    num_frames=args.num_frames,
    num_worker=args.num_worker,
    processor=processor,
)

In [None]:
# Location of the pre-processed dataset cache for the configured frame count.
dataset_path = f"{args.dataset_dir}/moment_retrieval/timestamp/{args.num_frames}_frames"

try:
    dataset = load_from_disk(dataset_path)
except Exception as err:
    # Any load failure still falls back to rebuilding the dataset, but unlike a
    # bare ``except:`` this lets KeyboardInterrupt/SystemExit propagate and
    # records why the cache could not be used.
    logger.warning("Could not load dataset from %s (%r); rebuilding it.", dataset_path, err)
    dataset = dp.prepare_dataset(use_frame=False)

In [None]:
# Small fixed-size subsets of each split.
# NOTE(review): the dataloaders below are built from the full splits, not from
# these subsets - confirm which one is intended before relying on either.
train_dataset = dataset['train'].select(range(40))
eval_dataset = dataset['validation'].select(range(24))


def _build_dataloader(split):
    """Return a non-shuffling DataLoader over ``split``.

    Batch size and worker count come from ``args``; each loader gets its own
    ``DataCollatorWithPadding`` built around the shared ``processor``.
    """
    return DataLoader(
        split,
        collate_fn=DataCollatorWithPadding(processor),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_worker,
    )


train_dataloader = _build_dataloader(dataset['train'])
eval_dataloader = _build_dataloader(dataset['validation'])


In [None]:
# 4-bit NF4 quantization with double quantization; compute dtype is bfloat16.
quantization_kwargs = {
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_compute_dtype": torch.bfloat16,
    "bnb_4bit_use_double_quant": True,
}
bnb_config = BitsAndBytesConfig(**quantization_kwargs)

# Load the quantized base model and let ``device_map="auto"`` decide placement.
model = VideoLlavaForConditionalGeneration.from_pretrained(
    "LanguageBind/Video-LLaVA-7B-hf",
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Post-load tweaks: dict-style outputs and a default cap on generated tokens.
model.config.return_dict = True
model.generation_config.max_new_tokens = 32

In [None]:
# Target modules are collected from the model before it is prepared and wrapped.
lora_target_modules = find_all_linear_names(model)

lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=args.lora_r,
    lora_alpha=args.lora_alpha,
    lora_dropout=0.1,
    init_lora_weights="gaussian",
    target_modules=lora_target_modules,
)

# Prepare the quantized model for k-bit training, then attach the LoRA adapters.
# Note that ``model`` is rebound here, so this cell is not meant to be re-run
# without reloading the base model first.
kbit_ready_model = prepare_model_for_kbit_training(model)
model = get_peft_model(kbit_ready_model, lora_config)

In [None]:
from lightning import LightningModule
from project.trainer.metrics import ao_exact_score, mr_iou_score
from deepspeed.ops.adam import DeepSpeedCPUAdam
from torch import Tensor
from bitsandbytes.optim.adam import Adam8bit


class VideoLlavaModelPLModule(LightningModule):
    """Lightning module that trains and validates a Video-LLaVA model.

    Note: this definition shadows the ``VideoLlavaModelPLModule`` imported from
    ``project.trainer.lightning`` at the top of the notebook.

    Args:
        config: Mapping of hyper-parameters; ``config.get("lr")`` is used as the
            optimizer learning rate.
        processor: Object whose ``batch_decode`` converts generated token ids
            back into text during validation.
        model: The wrapped model. It is called with ``labels`` in
            ``training_step`` and through ``generate`` in ``validation_step``.
    """

    def __init__(self, config, processor, model):
        super().__init__()
        # The model itself is excluded from the saved hyper-parameters.
        self.save_hyperparameters(ignore=["model"])
        self.config = config
        self.processor = processor
        self.model = model

    def training_step(self, batch):
        """Run a supervised forward pass and return the model's loss.

        ``batch`` is unpacked as
        ``(input_ids, attention_mask, pixel_values_videos, labels)``.
        The loss is also logged as ``train_loss``.
        """
        input_ids, attention_mask, pixel_values_videos, labels = batch

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values_videos=pixel_values_videos,
            labels=labels,
        )
        loss = outputs.loss

        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch):
        """Generate predictions for a batch and log their ``mr_iou_score``.

        ``batch`` is unpacked as
        ``(input_ids, attention_mask, pixel_values_videos, labels, frame_info)``.

        Returns:
            The second value returned by ``mr_iou_score`` (``correct``).
        """
        input_ids, attention_mask, pixel_values_videos, labels, frame_info = batch

        # Greedy autoregressive generation of token ids.
        generated_ids = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values_videos=pixel_values_videos,
            max_new_tokens=50,
            do_sample=False,
        )

        # Decode only the newly generated tokens, dropping the prompt prefix.
        predictions = self.processor.batch_decode(
            generated_ids[:, input_ids.size(1):],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

        score, correct = mr_iou_score(predictions, frame_info, labels)

        # Surface the score in the progress bar instead of printing it per batch.
        self.log("val_accuracy", score, prog_bar=True)

        return correct

    def configure_optimizers(self):
        """Return an 8-bit Adam optimizer over the trainable parameters.

        Frozen parameters never receive gradients, so only parameters with
        ``requires_grad=True`` are handed to the optimizer.
        """
        trainable_params = [param for param in self.parameters() if param.requires_grad]
        return Adam8bit(trainable_params, min_8bit_size=4096, lr=self.config.get("lr"))

In [None]:
# Hyper-parameters consumed by the Lightning module (it reads the "lr" key).
module_config = {"lr": args.learning_rate}

# Wrap the processor and the adapter-equipped model for use with the Trainer.
module = VideoLlavaModelPLModule(config=module_config, processor=processor, model=model)

In [None]:
# Round the configured validation limit down to a whole multiple of the batch size.
# NOTE(review): Lightning treats an integer ``limit_val_batches`` as a number of
# batches, while this rounding looks sample-oriented - confirm the intended unit.
limit_val_batches = (args.limit_val_batches // args.batch_size) * args.batch_size

# Keyword arguments forwarded to the Lightning Trainer.
train_conf = {
    "max_epochs": args.max_epoch,
    # Taken from the config (it was previously pinned to 1, which silently
    # ignored ``Config.accumulate_grad_batches``).
    "accumulate_grad_batches": args.accumulate_grad_batches,
    "limit_val_batches": int(limit_val_batches),
    "val_check_interval": args.val_check_interval,
    "precision": "16-mixed",
    "gradient_clip_val": 1.0,
    "num_sanity_val_steps": None,
}
logger.info(str(train_conf))

In [None]:
# Build the Trainer on device index 0; remaining options come from ``train_conf``.
trainer = Trainer(
    accelerator="auto",
    devices=[0],
    # callbacks=callbacks,
    **train_conf,
)

In [None]:
# Evaluate on the validation loader once, before any training step is taken.
trainer.validate(model=module, dataloaders=eval_dataloader)

In [None]:
# Fine-tune, validating on the evaluation loader at the configured interval.
trainer.fit(model=module, train_dataloaders=train_dataloader, val_dataloaders=eval_dataloader)