# Define variables

In [2]:
MAX_LENGTH = 1536
MODEL_ID = "llava-hf/llava-v1.6-vicuna-13b-hf"

## Load processor

In [3]:
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(MODEL_ID)
processor.tokenizer.padding_side = "right" # during training, one always uses padding on the right

2025-06-25 01:20:00.863965: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-06-25 01:20:00.873235: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1750807200.884902 3931450 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750807200.888500 3931450 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1750807200.898982 3931450 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

## Load model

In [4]:
from transformers import BitsAndBytesConfig, LlavaNextForConditionalGeneration
import torch
import gc

USE_LORA = False
USE_QLORA = False

## Load model

# Three options for training, from the lowest precision training to the highest precision training:
# - QLora
# - Standard Lora
# - Full fine-tuning
if USE_QLORA or USE_LORA:
    if USE_QLORA:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16
        )
    model = LlavaNextForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        quantization_config=bnb_config,
    )
else:
    # for full fine-tuning, we can speed up the model using Flash Attention
    # only available on certain devices, see https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features
    model = LlavaNextForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        _attn_implementation="flash_attention_2",
    )

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

## Apply PEFT

After loading the base model, we're going to add LoRa adapter layers. We're going to only train these adapter layers (the base model is kept frozen).

In [5]:
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model


def find_all_linear_names(model):
    cls = torch.nn.Linear
    lora_module_names = set()
    multimodal_keywords = ['multi_modal_projector', 'vision_model']
    for name, module in model.named_modules():
        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
            continue
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if 'lm_head' in lora_module_names: # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)


lora_config = LoraConfig(
    r=8,
    lora_alpha=8,
    lora_dropout=0.1,
    target_modules=find_all_linear_names(model),
    init_lora_weights="gaussian",
)

model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)

## Create PyTorch dataset

In [6]:
from torch.utils.data import Dataset
from typing import Any, Dict
import random
from PIL import Image
from torchvision import transforms as T
from torchvision.transforms import InterpolationMode

class LlavaNextDataset(Dataset):
    """
    PyTorch Dataset for LLaVa-NeXT. This class takes a HuggingFace Dataset as input.

    Each row, consists of image path(png/jpg/jpeg) and ground truth data (json/jsonl/txt).
    """

    def __init__(
        self,
        dataset: Dataset,
        split: str = "train",
        sort_json_key: bool = True,
    ):
        super().__init__()

        self.split = split
        self.sort_json_key = sort_json_key

        self.dataset = dataset[split]
        self.dataset_length = len(self.dataset)
        self.prompt = """You are a system that checks if a photo meets passport photo requirements.
Respond in exactly one of two ways:
- "Acceptable"
- "Unacceptable: <One ore more criteria>"

Criteria to consider: 
Format: Head size must be just right and be in the centre of the frame
Photo quality:  In sharp focus and clear. Neutral colour, natural skin tones, no red eyes.
Lighting: Appropriate brightness and contrast.  Balanced lighting, no shadows or flash reflections on face.
Eyes: Directly looking at the camera. Eyes open and clearly visible, no hair across the eyes.
Pose: Face must be in the centre. Portrait style and tilted positions are not acceptable. The photograph must show both sides of the face evenly
Background: Plain light-coloured (single-coloured) background. The photographed person must be shown alone with clear background
Glasses: Eyes must be showed clearly with no flash reflections on glasses.  No tinted glasses. Frames must not cover any part of the eyes.
Head coverings: Head coverings are not permitted except for religious reasons. Facial features from bottom of the chin to top of forehead and both sides of the face must be clearly shown.
Facial Expression: Facial expression must be neutral. Mouth must be closed.
"""
        self.transform = T.Compose([
            T.Resize((processor.image_processor.size["shortest_edge"],   # 336
                      processor.image_processor.size["shortest_edge"]),
                     interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),                                # (3, H, W), 0-1 floats
        ])

    def __len__(self) -> int:
        return self.dataset_length

    def __getitem__(self, idx: int) -> Dict:
        record = self.dataset[idx]
        instruction = self.prompt
        output = record["output"]

        # Load and convert to RGB
        image = record['image'].convert("RGB")
        if self.transform:
            image = self.transform(image)

        return {
            "image": image,
            "instruction": instruction,
            "output": output
        }

Let's instantiate the PyTorch datasets:

In [7]:
from datasets import load_dataset
dataset = load_dataset('parquet', data_files={"train": "./data/train/dataset.parquet"})
train_dataset = LlavaNextDataset(dataset,  split="train", sort_json_key=False)

As always, it's important to check your data. Let's check the first example:

In [7]:
train_example = train_dataset[0]
print(train_example)

{'image': tensor([[[0.9490, 0.9529, 0.9529,  ..., 0.9529, 0.9451, 0.9451],
         [0.9490, 0.9451, 0.9451,  ..., 0.9529, 0.9490, 0.9451],
         [0.9490, 0.9451, 0.9451,  ..., 0.9569, 0.9529, 0.9490],
         ...,
         [0.9451, 0.9412, 0.9412,  ..., 0.0039, 0.0039, 0.0039],
         [0.9451, 0.9412, 0.9255,  ..., 0.0039, 0.0039, 0.0000],
         [0.9412, 0.9490, 0.8824,  ..., 0.0039, 0.0039, 0.0078]],

        [[0.9569, 0.9529, 0.9569,  ..., 0.9569, 0.9529, 0.9490],
         [0.9490, 0.9490, 0.9529,  ..., 0.9569, 0.9529, 0.9490],
         [0.9490, 0.9490, 0.9490,  ..., 0.9608, 0.9529, 0.9529],
         ...,
         [0.9569, 0.9490, 0.9608,  ..., 0.0000, 0.0039, 0.0078],
         [0.9569, 0.9529, 0.9451,  ..., 0.0000, 0.0039, 0.0157],
         [0.9608, 0.9608, 0.9137,  ..., 0.0039, 0.0039, 0.0078]],

        [[0.9569, 0.9608, 0.9569,  ..., 0.9529, 0.9451, 0.9412],
         [0.9608, 0.9569, 0.9529,  ..., 0.9490, 0.9451, 0.9451],
         [0.9569, 0.9529, 0.9529,  ..., 0.9529, 

## Define collate functions

In [8]:
def train_collate_fn(examples):
    images, texts = [], []

    for ex in examples:
        image       = ex["image"]
        instruction = ex["instruction"]
        output      = ex["output"]

        images.append(image)

        conversation = [
            {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": instruction},
            ]},
            {"role": "assistant", "content": [
                {"type": "text", "text": output},
            ]},
        ]
        texts.append(processor.apply_chat_template(conversation))

    batch = processor(
        text=texts,
        images=images,
        padding=True,
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt",
    )

    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    batch["labels"] = labels
    return (
        batch["input_ids"],
        batch["attention_mask"],
        batch["pixel_values"],
        batch["image_sizes"],
        batch["labels"],
    )

def eval_collate_fn(examples):
    # We only feed the prompt to the model, so we don't add assistant's turn
    # Rather we indicate to `add_generation_prompt=True`

    images = []
    texts = []
    answers = []
    for example in examples:
        image, instruction, output = example
        images.append(image)

        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": instruction},
                ],
            },
        ]
        text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        texts.append(text_prompt)
        answers.append(output)

    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    pixel_values = batch["pixel_values"]
    image_sizes = batch["image_sizes"]

    return input_ids, attention_mask, pixel_values, image_sizes, answers

## Define PyTorch LightningModule

In [9]:
import pytorch_lightning as L
from torch.utils.data import DataLoader
import re
from nltk import edit_distance
import numpy as np


class LlavaModelPLModule(L.LightningModule):
    def __init__(self, config, processor, model, train_dataset):
        super().__init__()
        self.config = config
        self.processor = processor
        self.model = model
        self.train_dataset = train_dataset

        self.batch_size = config.get("batch_size")

    def training_step(self, batch, batch_idx):

        input_ids, attention_mask, pixel_values, image_sizes, labels = batch

        outputs = self.model(input_ids=input_ids,
                            attention_mask=attention_mask,
                            pixel_values=pixel_values,
                            image_sizes=image_sizes,
                            labels=labels
                          )
        loss = outputs.loss

        self.log("train_loss", loss)

        return loss

    def validation_step(self, batch, batch_idx, dataset_idx=0):

        input_ids, attention_mask, pixel_values, image_sizes, answers = batch

        # autoregressively generate token IDs
        generated_ids = self.model.generate(input_ids=input_ids, attention_mask=attention_mask,
                                       pixel_values=pixel_values, image_sizes=image_sizes, max_new_tokens=MAX_LENGTH)
        # turn them back into text, chopping of the prompt
        # important: we don't skip special tokens here, because we want to see them in the output
        predictions = self.processor.batch_decode(generated_ids[:, input_ids.size(1):], skip_special_tokens=True)

        scores = []
        for pred, answer in zip(predictions, answers):
            pred = re.sub(r"(?:(?<=>) | (?=</s_))", "", pred)
            scores.append(edit_distance(pred, answer) / max(len(pred), len(answer)))

            if self.config.get("verbose", False) and len(scores) == 1:
                print(f"Prediction: {pred}")
                print(f"    Answer: {answer}")
                print(f" Normed ED: {scores[0]}")

        self.log("val_edit_distance", np.mean(scores))

        return scores

    def configure_optimizers(self):
        # you could also add a learning rate scheduler if you want
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.config.get("lr"))

        return optimizer

    def train_dataloader(self):
        return DataLoader(self.train_dataset, collate_fn=train_collate_fn, batch_size=self.batch_size, shuffle=True, num_workers=4)

# Configure, train and save

In [None]:
config = {"max_epochs": 1,
          "check_val_every_n_epoch": 1,
          "gradient_clip_val": 1.0,
          "accumulate_grad_batches": 8,
          "lr": 2e-5,
          "batch_size": 1,
          "num_nodes": 1,
          "warmup_steps": 50,
          "result_path": "./result",
          "verbose": True,
}

model_module = LlavaModelPLModule(config, processor, model, train_dataset)

trainer = L.Trainer(
        enable_checkpointing=False, 
        accelerator="gpu",
        devices=[0],
        max_epochs=config.get("max_epochs"),
        accumulate_grad_batches=config.get("accumulate_grad_batches"),
        check_val_every_n_epoch=config.get("check_val_every_n_epoch"),
        gradient_clip_val=config.get("gradient_clip_val"),
        precision="16-mixed",
        limit_val_batches=5,
        num_sanity_val_steps=0
)

trainer.fit(model_module)


peft_model = model_module.model

# ✨ option A: adapters only (tiny file)
peft_model.save_pretrained("./fine-tuned-weights/lr2e-5___")

In [1]:
!rm -rf lightning_logs