In [None]:
from project.dataset.prepare import DatasetPreparer
from project.dataset.collate import DataCollatorWithPadding
from project.trainer.lightning import VideoLlavaModelPLModule
from project.trainer.peft import find_all_linear_names

from transformers import (
    VideoLlavaProcessor,
    BitsAndBytesConfig,
    VideoLlavaForConditionalGeneration
)
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from datasets import load_from_disk
from torch.utils.data import DataLoader
import torch
from dataclasses import dataclass
from lightning import Trainer
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

In [None]:
@dataclass
class model_conf:
    """Model-side settings shared by the cells that load and adapt the model.

    The attributes are annotated so that ``@dataclass`` registers them as real
    fields; un-annotated assignments are plain class attributes that the
    decorator ignores. Class-level access (``model_conf.model_id``) keeps
    working because the defaults remain class attributes.
    """

    # Checkpoint identifier handed to ``from_pretrained`` when loading the model.
    model_id: str = "LanguageBind/Video-LLaVA-7B-hf"
    # Passed as ``r`` to ``LoraConfig``.
    lora_r: int = 16

In [None]:
# Load the processor from the same checkpoint id that the model is loaded from,
# so editing model_conf.model_id cannot leave the two out of sync.
processor = VideoLlavaProcessor.from_pretrained(model_conf.model_id, use_fast=False)
# Set explicitly on the processor instance.
# NOTE(review): presumably needed so the processor can expand image/video
# placeholder tokens for this checkpoint — confirm against the installed
# transformers version.
processor.patch_size = 14
processor.vision_feature_select_strategy = "default"
# Flag the "Asking-to-pad-a-fast-tokenizer" entry in the tokenizer's
# deprecation_warnings dict (presumably so that warning is not printed during
# collation — confirm).
processor.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True

In [None]:
# Project helper bound to the processor; used below as the fallback that
# builds the dataset when no processed copy can be loaded from disk.
dp = DatasetPreparer(
    base_dir="datasets",
    processed_dir="processed",
    processor=processor,
    num_frames=14,
)

In [None]:
# Reuse the already-processed dataset when it is cached on disk; build it from
# scratch only when the cache is missing.
try:
    dataset = load_from_disk("datasets/processed/action_ordering_v2/robust/14_frames")
except FileNotFoundError:
    # load_from_disk raises FileNotFoundError when the directory is absent or
    # is not a saved dataset. Anything else (including KeyboardInterrupt) now
    # propagates instead of silently triggering a full re-preparation.
    dataset = dp.prepare_dataset('action_ordering_v2', use_robust=True)

In [None]:
# Training batches are reshuffled every epoch so the optimizer does not see the
# examples in the same fixed order each pass; evaluation keeps a fixed order.
# NOTE(review): batch_size is hard-coded to 4 here while train_conf.batch_size
# is 2 — confirm which value is intended.
train_dataloader = DataLoader(
    dataset['train'],
    collate_fn=DataCollatorWithPadding(processor),
    batch_size=4,
    shuffle=True,
    num_workers=2,
)
eval_dataloader = DataLoader(
    dataset['test'],
    collate_fn=DataCollatorWithPadding(processor),
    batch_size=4,
    shuffle=False,
    num_workers=2,
)

In [None]:
# 4-bit NF4 quantization with double quantization enabled; float16 is used for
# both the compute dtype and the quantized-weight storage dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_storage=torch.float16,
)

# Load the checkpoint named in model_conf with the quantization config applied,
# letting device_map="auto" decide the device placement.
model = VideoLlavaForConditionalGeneration.from_pretrained(
    model_conf.model_id,
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
    device_map="auto",
)

In [None]:
# LoRA adapter settings: rank taken from model_conf, dropout of 0.1, and
# Gaussian initialisation of the adapter weights. The target modules are
# whatever the project's find_all_linear_names helper returns for this model.
lora_config = LoraConfig(
    target_modules=find_all_linear_names(model),
    r=model_conf.lora_r,
    lora_dropout=0.1,
    init_lora_weights="gaussian",
)

In [None]:
# Run the quantized model through PEFT's k-bit training preparation, then wrap
# the result with the adapters described by lora_config. `model` is rebound to
# the wrapped model.
model = get_peft_model(prepare_model_for_kbit_training(model), lora_config)

In [None]:
# Bare expression: the notebook renders the model's repr as this cell's output.
model

In [None]:
@dataclass
class train_conf:
    """Training hyper-parameters read by the later module and trainer cells.

    The attributes are annotated so that ``@dataclass`` registers them as real
    fields; un-annotated assignments are plain class attributes that the
    decorator ignores. Attribute names are unchanged, so class-level access
    such as ``train_conf.lr`` keeps working.
    """

    max_epoch: int = 2
    # NOTE(review): the DataLoader cell hard-codes batch_size=4 and does not
    # read this value — confirm which one is intended.
    batch_size: int = 2
    num_nodes: int = 1
    accumulate_grad_batches: int = 4
    # Learning rate forwarded to the Lightning module via its config dict.
    lr: float = 2e-5
    limit_val_batches: int = 32
    val_check_interval: float = 1 / 4
    precision: str = "16-mixed"

In [None]:
# Wrap the PEFT model and the processor in the project's Lightning module; the
# only config entry supplied is the learning rate from train_conf.
module = VideoLlavaModelPLModule(
    model=model,
    processor=processor,
    config=dict(lr=train_conf.lr),
)

In [None]:
# Accuracy is a higher-is-better metric, so stop when it stops *increasing*.
# With mode="min" the callback would treat falling accuracy as improvement.
# NOTE(review): assumes the Lightning module logs a metric named
# "val_accuracy" — confirm in VideoLlavaModelPLModule.
early_stopping = EarlyStopping(monitor="val_accuracy", verbose=False, mode="max")

In [None]:
callbacks = [
    early_stopping
]

# Pass the Trainer options explicitly. Unpacking vars(train_conf) forwards the
# class __dict__, which contains dunder entries (__module__, __dict__, ...) and
# keys Trainer does not accept (max_epoch, batch_size, lr), so the constructor
# raises TypeError. batch_size and lr are consumed by the DataLoader / module,
# not by the Trainer.
trainer = Trainer(
    max_epochs=train_conf.max_epoch,
    num_nodes=train_conf.num_nodes,
    accumulate_grad_batches=train_conf.accumulate_grad_batches,
    limit_val_batches=train_conf.limit_val_batches,
    val_check_interval=train_conf.val_check_interval,
    precision=train_conf.precision,
    callbacks=callbacks,
)