In [1]:
from project.dataset.collate import DataCollatorWithPadding
from project.dataset.prepare import DatasetPreparer
from project.trainer.peft import find_all_linear_names

from transformers import (
    VideoLlavaProcessor,
    BitsAndBytesConfig,
    VideoLlavaForConditionalGeneration
)
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from datasets import load_from_disk
from torch.utils.data import DataLoader
import torch
from dataclasses import dataclass
from lightning import Trainer
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
import argparse
import logging
import os
import sys
from datetime import datetime

In [2]:
from lightning import LightningModule
from project.trainer.metrics import ao_exact_score
from bitsandbytes.optim import Adam8bit


class VideoLlavaModelPLModule(LightningModule):
    """Lightning wrapper that trains and validates a (PEFT-wrapped) Video-LLaVA model.

    Args:
        config: dict of training options; only ``config["lr"]`` is read
            (by ``configure_optimizers``).
        processor: processor whose ``batch_decode`` turns generated token ids
            back into text during validation.
        model: model called with ``input_ids``, ``attention_mask``,
            ``pixel_values_videos`` and ``labels`` for training, and whose
            ``generate`` method is used for validation.
    """

    def __init__(self, config, processor, model):
        super().__init__()
        # The model is checkpointed as a submodule already; don't also pickle
        # it into the saved hyperparameters.
        self.save_hyperparameters(ignore=['model'])
        self.config = config
        self.processor = processor
        self.model = model

    def training_step(self, batch):
        """Run one teacher-forced forward pass and return the model's loss.

        Only the first four batch entries are used
        (input_ids, attention_mask, pixel_values_videos, labels); slicing
        tolerates collators that append extra entries after the labels.
        """
        input_ids, attention_mask, pixel_values_videos, labels = batch[:4]

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values_videos=pixel_values_videos,
            labels=labels
        )
        loss = outputs.loss

        # Explicit batch_size spares Lightning from guessing it from a tuple
        # batch; prog_bar keeps the loss visible without printing every step.
        self.log("train_loss", loss, prog_bar=True, batch_size=input_ids.size(0))

        return loss

    def validation_step(self, batch):
        """Greedy-decode answers for a validation batch and log ``val_accuracy``.

        Returns:
            ``correct`` as returned by ``ao_exact_score`` for action-ordering
            batches, otherwise ``len(predictions)``.
        """
        # Slice instead of unpacking the whole batch: a 4-way unpack crashes
        # when the collator appends extra entries (e.g. frame info).
        input_ids, attention_mask, pixel_values_videos, labels = batch[:4]
        # NOTE(review): heuristic — a batch counts as action-ordering (AO) when
        # labels[0][0] has length 1; confirm against the collator's output.
        is_ao = len(labels[0][0]) == 1

        # Autoregressively generate token IDs (greedy, no sampling).
        generated_ids = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values_videos=pixel_values_videos,
            max_new_tokens=50,
            do_sample=False,
        )
        # Turn them back into text, chopping off the prompt tokens.
        predictions = self.processor.batch_decode(
            generated_ids[:, input_ids.size(1):],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

        if is_ao:
            score, correct = ao_exact_score(predictions, labels)
        else:
            # TODO: no metric for non-AO tasks yet; this logs a constant 1,
            # which the val_accuracy-monitoring callbacks see as a perfect score.
            score = 1
            correct = len(predictions)
        self.log("val_accuracy", score, batch_size=input_ids.size(0))

        return correct

    def configure_optimizers(self):
        """Create a bitsandbytes 8-bit Adam over all module parameters.

        ``min_8bit_size=4096``: per the bitsandbytes docs, parameter tensors
        with fewer elements are not optimized in 8-bit.

        NOTE(review): DeepSpeed ZeRO-Offload strategies may reject or warn
        about client-provided optimizers like this one — confirm it works with
        the "deepspeed_stage_2_offload" strategy used by the Trainer.
        """
        optimizer = Adam8bit(self.parameters(), min_8bit_size=4096, lr=self.config.get("lr"))

        return optimizer

In [3]:
@dataclass
class Config:
    """Hyperparameters and paths for the Video-LLaVA QLoRA fine-tuning run."""

    # LoRA adapter rank and alpha passed to LoraConfig.
    lora_r: int = 4
    lora_alpha: int = 8
    batch_size: int = 1
    max_epoch: int = 2
    # Passed to the Trainer as val_check_interval (fraction of an epoch).
    val_check_interval: float = 0.25
    learning_rate: float = 2e-5
    # "<base_dir>/<processed_dir>" root of the prepared datasets.
    dataset_dir: str = "datasets/processed"
    num_frames: int = 14
    num_worker: int = 2
    hub_repo: str = "jwnt4/finetune-videollava-qlora"
    accumulate_grad_batches: int = 1
    # Upper bound used to derive the Trainer's limit_val_batches.
    limit_val_batches: float = 16

# Bind an instance rather than the class itself; the class only "worked"
# because dataclass defaults also exist as class attributes, which breaks
# as soon as a field without a default (or a per-run override) is added.
args = Config()

In [4]:
# Per the PyTorch docs, "medium" lets float32 matmuls use faster, lower-precision internal kernels (e.g. bfloat16) when available.
torch.set_float32_matmul_precision(precision="medium")

In [5]:
# Module logger: INFO and above go to stdout and to a timestamped file in logs/.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Re-running this cell must not stack duplicate handlers (each would repeat
# every message), so detach and close any previously attached ones.
for old_handler in list(logger.handlers):
    logger.removeHandler(old_handler)
    old_handler.close()

os.makedirs("logs", exist_ok=True)

# strftime avoids the spaces and colons of str(datetime); colons are not
# valid in Windows file names.
log_file = os.path.join("logs", f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S-%f')}.log")

log_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S")

stream_handler = logging.StreamHandler(sys.stdout)
stream_handler.setFormatter(log_formatter)

# Format file output the same way as stdout so log lines keep timestamps and levels.
file_handler = logging.FileHandler(log_file)
file_handler.setFormatter(log_formatter)

logger.addHandler(stream_handler)
logger.addHandler(file_handler)


In [6]:
# Video-LLaVA processor; use_fast=False presumably selects the slow tokenizer.
processor = VideoLlavaProcessor.from_pretrained(
    "LanguageBind/Video-LLaVA-7B-hf",
    use_fast=False,
)

# Vision settings read by the processor (presumably to size the visual token
# sequence — confirm for the installed transformers version). 14 matches the
# CLIP vision tower's 14x14 patch embedding shown in the model printout.
processor.patch_size = 14
processor.vision_feature_select_strategy = "default"

# Mark the "pad a fast tokenizer" advice as already emitted so it is not logged.
# NOTE(review): that advice presumably only fires for fast tokenizers, so with
# use_fast=False this line may be a no-op.
processor.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True

In [7]:
# Split "<base_dir>/<processed_dir>" at its last separator. Splitting on every
# "/" and taking [0]/[1] silently dropped components of nested paths (e.g.
# "data/v2/processed"), so the prepared location would no longer match the
# `args.dataset_dir/...` path used by load_from_disk.
base_dir, processed_dir = os.path.split(os.path.normpath(args.dataset_dir))
dp = DatasetPreparer(
    base_dir=base_dir,
    processed_dir=processed_dir,
    num_frames=args.num_frames,
    num_worker=args.num_worker,  # honor Config.num_worker instead of a hard-coded 2
    processor=processor,
)

In [8]:
# Reuse the prepared dataset on disk when present; otherwise prepare it.
# NOTE(review): assumes prepare_dataset saves to dataset_path so later runs
# hit the cache — confirm in DatasetPreparer.
dataset_path = f"{args.dataset_dir}/action_ordering_v2/robust/{args.num_frames}_frames"
try:
    dataset = load_from_disk(dataset_path)
except FileNotFoundError:
    # load_from_disk raises FileNotFoundError for a missing or non-dataset
    # directory — the only "not prepared yet" signal. Other errors (corrupt
    # files, KeyboardInterrupt, ...) should surface instead of silently
    # triggering a full re-preparation.
    dataset = None
if dataset is None:
    logger.info("No prepared dataset at %s; preparing it now.", dataset_path)
    dataset = dp.prepare_dataset('action_ordering_v2', use_robust=True)

In [9]:
    # NOTE(review): this cell is indented at top level; IPython strips a uniform
    # leading indent, but an exported .py script would raise IndentationError.
    # NOTE(review): the training loader does not shuffle — confirm intentional.
    train_dataloader = DataLoader(
        dataset['train'],
        collate_fn=DataCollatorWithPadding(processor),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_worker,  # honor Config.num_worker instead of a hard-coded 2
    )
    eval_dataloader = DataLoader(
        dataset['validation'],
        collate_fn=DataCollatorWithPadding(processor),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_worker,
    )

In [10]:
# 4-bit NF4 quantization with double quantization; compute runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the quantized base model; device_map="auto" delegates device placement
# to accelerate.
model = VideoLlavaForConditionalGeneration.from_pretrained(
    "LanguageBind/Video-LLaVA-7B-hf",
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# Default generation length. NOTE(review): validation_step passes
# max_new_tokens=50 explicitly, which overrides this value there.
model.generation_config.max_new_tokens = 40

The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [11]:
# LoRA adapters on the linear layers picked by the project helper
# find_all_linear_names (evaluated on the quantized base model, before PEFT
# preparation/wrapping).
lora_config = LoraConfig(
    r=args.lora_r,
    lora_alpha=args.lora_alpha,
    lora_dropout=0.1,
    init_lora_weights="gaussian",
    target_modules=find_all_linear_names(model),
)

# Prepare the k-bit model for training (PEFT helper), then attach the adapters.
model = get_peft_model(prepare_model_for_kbit_training(model), lora_config)
model

PeftModel(
  (base_model): LoraModel(
    (model): VideoLlavaForConditionalGeneration(
      (video_tower): CLIPVisionModel(
        (vision_model): CLIPVisionTransformer(
          (embeddings): CLIPVisionEmbeddings(
            (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
            (position_embedding): Embedding(257, 1024)
          )
          (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (encoder): CLIPEncoder(
            (layers): ModuleList(
              (0-23): 24 x CLIPEncoderLayer(
                (self_attn): CLIPSdpaAttention(
                  (k_proj): lora.Linear4bit(
                    (base_layer): Linear4bit(in_features=1024, out_features=1024, bias=True)
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.1, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_features=1024, out_

In [12]:
# Show the resolved model config (incl. quantization) and generation defaults.
print(model.config, model.generation_config, sep="\n")

VideoLlavaConfig {
  "_attn_implementation_autoset": true,
  "_name_or_path": "LanguageBind/Video-LLaVA-7B-hf",
  "architectures": [
    "VideoLlavaForConditionalGeneration"
  ],
  "ignore_index": -100,
  "image_seq_length": 256,
  "image_token_index": 32000,
  "model_type": "video_llava",
  "pad_token_id": 32002,
  "projector_hidden_act": "gelu",
  "quantization_config": {
    "_load_in_4bit": true,
    "_load_in_8bit": false,
    "bnb_4bit_compute_dtype": "bfloat16",
    "bnb_4bit_quant_storage": "uint8",
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_use_double_quant": true,
    "llm_int8_enable_fp32_cpu_offload": false,
    "llm_int8_has_fp16_weight": false,
    "llm_int8_skip_modules": null,
    "llm_int8_threshold": 6.0,
    "load_in_4bit": true,
    "load_in_8bit": false,
    "quant_method": "bitsandbytes"
  },
  "text_config": {
    "_name_or_path": "lmsys/vicuna-7b-v1.5",
    "architectures": [
      "LlamaForCausalLM"
    ],
    "max_position_embeddings": 4096,
    "model_ty

In [13]:
# Wrap the PEFT model for Lightning; the module reads only "lr" from config.
module = VideoLlavaModelPLModule(
    config={"lr": args.learning_rate},
    processor=processor,
    model=model,
)

# val_accuracy is higher-is-better, so both callbacks must maximize it. With
# mode="min", EarlyStopping treats rising accuracy as "no improvement" and
# ModelCheckpoint (whose default mode is "min") keeps the worst checkpoint.
early_stopping = EarlyStopping(monitor="val_accuracy", mode="max", verbose=False)
model_checkpoint = ModelCheckpoint(
    monitor="val_accuracy",
    mode="max",
    dirpath="output/",
    # Trailing "-" separates the metric value from the LoRA suffix.
    filename="videollava-7b-ao-{epoch:02d}-{val_accuracy:.2f}-" + f"lora_r{args.lora_r}-lora_alpha{args.lora_alpha}",
)
callbacks = [early_stopping, model_checkpoint]

In [None]:
# Round the validation cap down to a multiple of batch_size, but never to 0:
# limit_val_batches=0 disables validation, and EarlyStopping (strict by
# default) monitoring val_accuracy would then fail.
# NOTE(review): Lightning counts limit_val_batches (and num_sanity_val_steps)
# in *batches*; scaling them by batch_size only makes sense if the config
# values are meant as sample counts — confirm intent.
limit_val_batches = max(
    args.batch_size,
    (args.limit_val_batches // args.batch_size) * args.batch_size,
)
train_conf = {
    "max_epochs": args.max_epoch,
    "accumulate_grad_batches": args.accumulate_grad_batches,
    "limit_val_batches": int(limit_val_batches),
    "val_check_interval": args.val_check_interval,
    # NOTE(review): float16 autocast while the model was loaded with bfloat16
    # dtype/compute — consider "bf16-mixed" if the GPU supports it.
    "precision": "16-mixed",
    "gradient_clip_val": 1.0,
    "num_sanity_val_steps": int(args.batch_size * 8),
}
logger.info("%s", train_conf)

trainer = Trainer(
    **train_conf,
    accelerator="auto",
    devices=[0],  # single device: GPU index 0
    callbacks=callbacks,
    strategy="deepspeed_stage_2_offload",
)

In [3]:
# Run training as a separate Python process executing the standalone script.
# NOTE(review): no `trainer.fit(...)` call is visible in this notebook; the
# script's output logs the same Trainer config, suggesting it repeats the
# setup above. If the cells above ran in this kernel, their 4-bit model still
# occupies GPU memory while the script loads its own copy — consider freeing
# it or restarting the kernel first.
!python deepspeed-finetune-qlora-ao.py

[2024-12-14 19:51:06,260] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)
The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.
Loading checkpoint shards: 100%|██████████████████| 3/3 [00:05<00:00,  1.78s/it]
/opt/conda/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:208: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.
12/14/2024 19:51:19 - INFO - {'max_epochs': 2, 'accumulate_grad_batches': 1, 'limit_val_batches': 16, 'val_check_interval': 0.25, 'precision': '16-mixed', 'gradient_clip_val': 1.0, 'num_sanity_val_steps': 8}
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logge