<a href="https://colab.research.google.com/github/grach0v/ORPO_LLaVA/blob/main/code/ORPO_LLAVA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install transformers trl datasets huggingface_hub bitsandbytes wandb tqdm



## Mount Drive (optional)

In [2]:
HF_CACHE = "/content/drive/MyDrive/hf_cache"

In [3]:
from google.colab import drive, runtime
import os
drive.mount("/content/drive")

# One shared cache for everything:
HF_CACHE = "/content/drive/MyDrive/hf_cache"
!mkdir -p "$HF_CACHE"

os.environ["HF_HOME"] = HF_CACHE           # generic root
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE
os.environ["HF_DATASETS_CACHE"] = HF_CACHE

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Imports

In [4]:
from __future__ import annotations

import contextlib
import os
from pathlib import Path
from typing import Dict, List, Tuple

from tqdm.auto import tqdm

import torch
from torch import Tensor
from torch.nn.functional import log_softmax, softplus
from torch.utils.data import DataLoader
from PIL import Image
from datasets import load_dataset
from transformers import (
    BitsAndBytesConfig,
    LlavaNextForConditionalGeneration,
    LlavaNextProcessor,
    get_cosine_schedule_with_warmup,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import wandb




## Config

In [5]:
MODEL_ID          = "llava-hf/llava-v1.6-mistral-7b-hf"
DATASET_NAME      = "openbmb/RLAIF-V-Dataset"
WANDB_PROJECT     = "llava-qlora-orpo"
OUTPUT_DIR        = "/content/drive/MyDrive/llava_orpo_adapters"

!mkdir -p "$OUTPUT_DIR"
USE_QLORA         = True

TRAIN_BATCH_SIZE  = 2 if USE_QLORA else 1
VAL_BATCH_SIZE    = 16
TEST_BATCH_SIZE   = 16
GRAD_ACC_STEPS    = 4          # effective batch = TRAIN_BATCH_SIZE × GRAD_ACC_STEPS
EPOCHS            = 1
LEARNING_RATE     = 2e-4
WARMUP_RATIO      = 0.03

VAL_RATIO         = 0.05       # 5 % validation
TEST_RATIO        = 0.05       # 5 % test

LORA_R            = 8  if USE_QLORA else 16
LORA_ALPHA        = 16 if USE_QLORA else 32

LORA_DROPOUT      = 0.05
ORPO_BETA         = 0.1

LOG_EVERY_STEPS   = 4
VAL_EVERY_STEPS   = 200

DEVICE = torch.device("cuda")

# Maximum number of tokens for the answer
MAX_ANSWER_TOKENS = 128 # You can adjust this value

## Load the model

In [6]:
processor = LlavaNextProcessor.from_pretrained(MODEL_ID, use_fast=True)
TOKENIZER = processor.tokenizer
EOS_ID = TOKENIZER.eos_token_id

if USE_QLORA:
    # 4‑bit base + gradient‑ckpt
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    base_model = LlavaNextForConditionalGeneration.from_pretrained(
        MODEL_ID,
        quantization_config=bnb_config,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map="auto",
    )

    base_model = prepare_model_for_kbit_training(
        base_model, use_gradient_checkpointing=True
    )
    base_model.config.use_cache = False     # must be OFF with grad‑ckpt
else:
    base_model = LlavaNextForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map="auto",
    )
    base_model.config.use_cache = True      # keep cache on


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [7]:
lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()

trainable params: 22,151,168 || all params: 7,588,898,816 || trainable%: 0.2919


## Create dataloader

In [8]:
@contextlib.contextmanager
def temporary_padding_side(tokenizer, side: str):
    """Temporarily change padding side (left/right) inside a `with` block."""
    original = tokenizer.padding_side
    tokenizer.padding_side = side
    try:
        yield
    finally:
        tokenizer.padding_side = original


def build_prompt_inputs(images: List[Image.Image], questions: List[str]) -> Dict[str, Tensor]:
    """Tokenise the (question + image placeholder) prompt with left‑padding."""
    conversations = [
        [{"role": "user", "content": [{"type": "text", "text": q}, {"type": "image"}]}]
        for q in questions
    ]
    prompts = [processor.apply_chat_template(c, add_generation_prompt=True) for c in conversations]
    encoded = processor(images=images, text=prompts, padding=True, return_tensors="pt")
    return {k: v.to(DEVICE) for k, v in encoded.items()}


def tokenize_answers(texts: List[str], max_length: int | None = None) -> Tuple[Tensor, Tensor]:
    """Right‑pad assistant answers and append EOS."""

    # Reason for the right pad tokenization -
    # later I will concatenate prompt tokens and potential answer tokens,
    # to get logits in one go, without writing a loop.
    # Having pad tokens in the middle seems very confusing
    # and can be misleading and couse errors in the future.

    with temporary_padding_side(TOKENIZER, "right"):
        encoded = TOKENIZER(
            texts,
            padding=True,
            truncation=True,
            add_special_tokens=False,
            return_tensors="pt",
        )
    ids, mask = encoded["input_ids"], encoded["attention_mask"]
    eos_column = torch.full((ids.size(0), 1), EOS_ID, dtype=ids.dtype)
    ids = torch.cat([ids, eos_column], dim=1)
    mask = torch.cat([mask, torch.ones_like(eos_column)], dim=1)

    if max_length is not None:
        # Trim if longer than max_length
        ids = ids[:, :max_length]
        mask = mask[:, :max_length]

    return ids.to(DEVICE), mask.to(DEVICE)


def collate_fn(batch):
    images = [item["image"] for item in batch]
    questions = [item["question"] for item in batch]
    chosen_texts = [item["chosen"] for item in batch]
    rejected_texts = [item["rejected"] for item in batch]

    prompt_inputs = build_prompt_inputs(images, questions)
    chosen_ids, chosen_mask = tokenize_answers(chosen_texts, max_length=MAX_ANSWER_TOKENS)
    rejected_ids, rejected_mask = tokenize_answers(rejected_texts, max_length=MAX_ANSWER_TOKENS)

    return prompt_inputs, chosen_ids, chosen_mask, rejected_ids, rejected_mask

In [9]:
raw_dataset = load_dataset(DATASET_NAME, split="train[:20%]") # Load only 20% of the training data
first_split = raw_dataset.train_test_split(test_size=VAL_RATIO + TEST_RATIO, seed=42)
train_dataset = first_split["train"]
val_test_dataset = first_split["test"]
val_fraction_of_tmp = VAL_RATIO / (VAL_RATIO + TEST_RATIO)
second_split = val_test_dataset.train_test_split(test_size=1 - val_fraction_of_tmp, seed=42)
val_dataset = second_split["train"]
test_dataset = second_split["test"]

train_loader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=VAL_BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

In [10]:
# Re-calculating the 95th percentile of the 'chosen' column length
chosen_lengths = [len(text) for text in raw_dataset["chosen"]]
chosen_lengths.sort()
quantile_95_length = chosen_lengths[int(len(chosen_lengths) * 0.95)]

# Setting the model's max_length for generation
model.generation_config.max_length = quantile_95_length
print(f"Set model.generation_config.max_length to {model.generation_config.max_length}")

Set model.generation_config.max_length to 932


# Get logits

In [11]:
def get_prompt_cache(prompt_inputs: Dict[str, Tensor], adapters_active: bool):
    """Run the prompt once; return `(last_logit, past_kv)`."""
    ctx = model.disable_adapter() if not adapters_active else contextlib.nullcontext()
    training_mode = model.training
    with ctx:
        model.eval()                       # temporarily enable KV cache
        with torch.no_grad():
            out = model(**prompt_inputs, use_cache=True, return_dict=True)
        if training_mode:                  # restore
            model.train()
    return out.logits[:, -1:, :], out.past_key_values

if USE_QLORA:
    # Overwrite get_prompt_cache
    def get_prompt_cache(
            prompt_inputs: Dict[str, Tensor],
            adapters_active: bool
    ) -> Tuple[Tensor, Tuple]:
        """
        Return (last_prompt_logit, past_key_values) even when the model was
        launched with gradient-checkpointing (QLoRA).

        Strategy
        --------
        • Disable LoRA adapters if `adapters_active` is False.
        • Walk every sub-module and flip `.gradient_checkpointing=False`.
        • Switch to `eval()` and `torch.no_grad()` for the single prompt pass.
        • Run with `use_cache=True` so the KV-cache is generated.
        • Restore *all* flags exactly as they were.
        """
        training_mode = model.training                   # remember .train/.eval state

        # ── snapshot every layer’s gradient-ckpt flag ─────────────────────
        ckpt_layers = []
        for mod in model.modules():
            if hasattr(mod, "gradient_checkpointing"):
                ckpt_layers.append((mod, mod.gradient_checkpointing))
                mod.gradient_checkpointing = False       # turn it off temporarily

        # LoRA on/off context
        adapter_ctx = model.disable_adapter() if not adapters_active else contextlib.nullcontext()

        with adapter_ctx, torch.no_grad():
            model.eval()                                 # no grads, cache allowed
            out = model(**prompt_inputs, use_cache=True, return_dict=True)

        # ── restore layer flags ───────────────────────────────────────────
        for mod, orig_val in ckpt_layers:
            mod.gradient_checkpointing = orig_val
        if training_mode:
            model.train()

        return out.logits[:, -1:, :], out.past_key_values

def sequence_logprob_from_cache(prompt_last_logit: Tensor,
                                past_key_values,
                                continuation_ids: Tensor,
                                continuation_mask: Tensor,
                                adapters_active: bool) -> Tensor:
    """Log‑probability of a continuation (incl. EOS) using an existing cache."""
    ctx = model.disable_adapter() if not adapters_active else contextlib.nullcontext()
    with ctx:
        logits_rest = model(
            input_ids=continuation_ids[:, :-1],
            attention_mask=continuation_mask[:, :-1],
            past_key_values=past_key_values,
            use_cache=False,
            pixel_values=None,
            image_sizes=None,
            return_dict=True,
        ).logits  # (B, N‑1, V)
    full_logits = torch.cat([prompt_last_logit, logits_rest], dim=1)  # (B, N, V)
    log_probs = log_softmax(full_logits, dim=-1)
    token_lp = log_probs.gather(2, continuation_ids.unsqueeze(-1)).squeeze(-1)
    return (token_lp * continuation_mask).sum(dim=-1)  # (B,)


In [None]:
def orpo_loss(lp_theta_c, lp_theta_r, lp_ref_c, lp_ref_r, beta=ORPO_BETA):
    return softplus(-beta * ((lp_theta_c - lp_theta_r) - (lp_ref_c - lp_ref_r))).mean()

# ───────────────────────────────────────────────────────────
# Optimiser, scheduler, wandb
# ───────────────────────────────────────────────────────────
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.95), weight_decay=0.0)
steps_per_epoch = len(train_loader) // GRAD_ACC_STEPS
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    int(steps_per_epoch * EPOCHS * WARMUP_RATIO),
    steps_per_epoch * EPOCHS,
)

wandb.init(project=WANDB_PROJECT, config={k: v for k, v in globals().items() if k.isupper()})

# ───────────────────────────────────────────────────────────
# Training Loop (with cache reuse)
# ───────────────────────────────────────────────────────────

model.train()
acc_steps = 0
running_loss = 0.0  # To accumulate loss for average
for epoch in range(EPOCHS):
    for global_step, batch in tqdm(enumerate(train_loader, 1), total=len(train_loader), desc=f"Epoch {epoch+1}/{EPOCHS}"):
        prompt_inputs, chosen_ids, chosen_mask, rejected_ids, rejected_mask = batch

        # θ policy
        last_theta, pkv_theta = get_prompt_cache(prompt_inputs, adapters_active=True)
        lp_theta_chosen   = sequence_logprob_from_cache(last_theta, pkv_theta, chosen_ids,   chosen_mask,   adapters_active=True)
        lp_theta_rejected = sequence_logprob_from_cache(last_theta, pkv_theta, rejected_ids, rejected_mask, adapters_active=True)

        # reference policy
        last_ref, pkv_ref = get_prompt_cache(prompt_inputs, adapters_active=False)
        with torch.no_grad():
            lp_ref_chosen   = sequence_logprob_from_cache(last_ref, pkv_ref, chosen_ids,   chosen_mask,   adapters_active=False)
            lp_ref_rejected = sequence_logprob_from_cache(last_ref, pkv_ref, rejected_ids, rejected_mask, adapters_active=False)

        loss = orpo_loss(lp_theta_chosen, lp_theta_rejected, lp_ref_chosen, lp_ref_rejected) / GRAD_ACC_STEPS
        loss.backward()
        acc_steps += 1
        running_loss += loss.item() * GRAD_ACC_STEPS # Accumulate original loss before division


        if acc_steps == GRAD_ACC_STEPS:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step(); scheduler.step(); optimizer.zero_grad(); acc_steps = 0

            if global_step % LOG_EVERY_STEPS == 0:
                avg_loss = running_loss / LOG_EVERY_STEPS
                wandb.log(
                    {"train/orpo_loss": avg_loss,
                     "lr": scheduler.get_last_lr()[0],
                     "train/lp_theta_chosen": lp_theta_chosen.mean().item(),
                     "train/lp_theta_rejected": lp_theta_rejected.mean().item(),
                     "train/lp_ref_chosen": lp_ref_chosen.mean().item(),
                     "train/lp_ref_rejected": lp_ref_rejected.mean().item()},
                    step=global_step
                )
                running_loss = 0.0 # Reset running loss


            if global_step % VAL_EVERY_STEPS == 0:
                model.eval()
                val_losses = []
                val_lp_theta_chosen_list = []
                val_lp_theta_rejected_list = []
                val_lp_ref_chosen_list = []
                val_lp_ref_rejected_list = []
                with torch.no_grad():
                    for vb in tqdm(val_loader, desc="Validation"):
                        p_in, c_id, c_m, r_id, r_m = vb

                        lt_last, lt_pkv = get_prompt_cache(p_in, adapters_active=True)
                        lr_last, lr_pkv = get_prompt_cache(p_in, adapters_active=False)

                        lt_c = sequence_logprob_from_cache(lt_last, lt_pkv, c_id, c_m, adapters_active=True)
                        lt_r = sequence_logprob_from_cache(lt_last, lt_pkv, r_id, r_m, adapters_active=True)
                        lr_c = sequence_logprob_from_cache(lr_last, lr_pkv, c_id, c_m, adapters_active=False)
                        lr_r = sequence_logprob_from_cache(lr_last, lr_pkv, r_id, r_m, adapters_active=False)

                        val_losses.append(orpo_loss(lt_c, lt_r, lr_c, lr_r).item())
                        val_lp_theta_chosen_list.append(lt_c.mean().item())
                        val_lp_theta_rejected_list.append(lt_r.mean().item())
                        val_lp_ref_chosen_list.append(lr_c.mean().item())
                        val_lp_ref_rejected_list.append(lr_r.mean().item())

                wandb.log({
                    "val/orpo_loss": sum(val_losses) / len(val_losses),
                    "val/lp_theta_chosen": sum(val_lp_theta_chosen_list) / len(val_lp_theta_chosen_list),
                    "val/lp_theta_rejected": sum(val_lp_theta_rejected_list) / len(val_lp_theta_rejected_list),
                    "val/lp_ref_chosen": sum(val_lp_ref_chosen_list) / len(val_lp_ref_chosen_list),
                    "val/lp_ref_rejected": sum(val_lp_ref_rejected_list) / len(val_lp_ref_rejected_list)
                    }, step=global_step)
                model.train()

[34m[1mwandb[0m: Currently logged in as: [33mgrach0v[0m ([33mcowboy_bebop[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 1/1:   0%|          | 0/7482 [00:00<?, ?it/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Caching is incompatible with gradient checkpointing in MistralDecoderLayer. Setting `past_key_value=None`.
Caching is incompatible with gradient checkpointing in MistralDecoderLayer. Setting `past_key_value=None`.
Caching is incompatible with gradient checkpointing in MistralDecoderLayer. Setting `past_key_value=None`.
Caching is incompatible with gradient checkpointing in MistralDecoderLayer. Setting `past_key_value=None`.
Caching is incompatible with gradient checkpointing in MistralDecoderLayer. Setting `past_key_value=None`.
Caching is incompatible with gradient checkpointing in MistralDecoderLayer. Setting `past_key_value=None`.
Caching is incompatible with gradient checkpointing in MistralDecoderLayer. Setting `past_key_value=None`.
Caching is incompatible with gradient checkpointing in MistralDecoderLayer. Setting `past_key_value=None`.
Cac

In [None]:
model.eval()
test_losses = []
with torch.no_grad():
    for tb in test_loader:
        p_in, c_id, c_m, r_id, r_m = tb
        lt_last, lt_pkv = get_prompt_cache(p_in, adapters_active=True)
        lr_last, lr_pkv = get_prompt_cache(p_in, adapters_active=False)

        lt_c = sequence_logprob_from_cache(lt_last, lt_pkv, c_id, c_m, adapters_active=True)
        lt_r = sequence_logprob_from_cache(lt_last, lt_pkv, r_id, r_m, adapters_active=True)
        lr_c = sequence_logprob_from_cache(lr_last, lr_pkv, c_id, c_m, adapters_active=False)
        lr_r = sequence_logprob_from_cache(lr_last, lr_pkv, r_id, r_m, adapters_active=False)

        test_losses.append(orpo_loss(lt_c, lt_r, lr_c, lr_r).item())

wandb.log({"test/orpo_loss": sum(test_losses) / len(test_losses)})

# ───────── save adapters & finish ─────────
model.save_pretrained(OUTPUT_DIR)
wandb.finish()