In [17]:
!pip install transformers trl datasets huggingface_hub bitsandbytes wandb tqdm pillow torchvision peft ipywidgets nbformat vllm

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting vllm
Collecting vllm
  Downloading vllm-0.9.2-cp38-abi3-manylinux1_x86_64.whl.metadata (15 kB)
  Downloading vllm-0.9.2-cp38-abi3-manylinux1_x86_64.whl.metadata (15 kB)
Collecting cachetools (from vllm)
Collecting cachetools (from vllm)
  Downloading cachetools-6.1.0-py3-none-any.whl.metadata (5.4 kB)
  Downloading cachetools-6.1.0-py3-none-any.whl.metadata (5.4 kB)
Collecting sentencepiece (from vllm)
  Downloading sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.7 kB)
Collecting sentencepiece (from vllm)
  Downloading sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.7 kB)
Collecting blake3 (from vllm)
  Downloading blake3-1.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.2 kB)
Collecting blake3 (from vllm)
  Downloading blake3-1.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.2 kB)
Collecting py-cpuinfo (from vllm)
  Downloading py_cpuin

## Mount Drive (optional)

In [2]:
# HF_CACHE = "/content/drive/MyDrive/hf_cache"

In [3]:
# from google.colab import drive, runtime
# import os
# drive.mount("/content/drive")

# # One shared cache for everything:
# HF_CACHE = "/content/drive/MyDrive/hf_cache"
# !mkdir -p "$HF_CACHE"

# os.environ["HF_HOME"] = HF_CACHE           # generic root
# os.environ["TRANSFORMERS_CACHE"] = HF_CACHE
# os.environ["HF_DATASETS_CACHE"] = HF_CACHE

## Imports

In [4]:
from __future__ import annotations

import contextlib
import os
from pathlib import Path
from typing import Dict, List, Tuple

from tqdm.auto import tqdm

import torch
from torch import Tensor
from torch.nn.functional import log_softmax, softplus
from torch.utils.data import DataLoader
from PIL import Image
from datasets import load_dataset
from transformers import (
    BitsAndBytesConfig,
    LlavaNextForConditionalGeneration,
    LlavaNextProcessor,
    get_cosine_schedule_with_warmup,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import wandb

## Config

In [5]:
MODEL_ID          = "llava-hf/llava-v1.6-mistral-7b-hf"
DATASET_NAME      = "openbmb/RLAIF-V-Dataset"
WANDB_PROJECT     = "llava-qlora-orpo"
OUTPUT_DIR        = "../logs/"

!mkdir -p "$OUTPUT_DIR"
# USE_QLORA         = False 
# QLORA had a lot of issues with the Mistral model
# SO I dropped it for now

TRAIN_BATCH_SIZE  = 1 
VAL_BATCH_SIZE    = 6
TEST_BATCH_SIZE   = 6
GRAD_ACC_STEPS    = 4          # effective batch = TRAIN_BATCH_SIZE × GRAD_ACC_STEPS
EPOCHS            = 1
LEARNING_RATE     = 2e-4
WARMUP_RATIO      = 0.03

VAL_RATIO         = 0.05      
TEST_RATIO        = 0.20  

LORA_R            = 8  
LORA_ALPHA        = 16 

LORA_DROPOUT      = 0.05
ORPO_LAMBDA       = 5

LOG_EVERY_STEPS   = 4
VAL_EVERY_STEPS   = 200

DEVICE = torch.device("cuda")

# Maximum number of tokens for the answer
MAX_ANSWER_TOKENS = 128 

## Load the model

In [None]:
processor = LlavaNextProcessor.from_pretrained(MODEL_ID, use_fast=True)
TOKENIZER = processor.tokenizer
EOS_ID = TOKENIZER.eos_token_id

base_model = LlavaNextForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
)

# if USE_QLORA:
#     # 4‑bit base + gradient‑ckpt
#     bnb_config = BitsAndBytesConfig(
#         load_in_4bit=True,
#         bnb_4bit_quant_type="nf4",
#         bnb_4bit_compute_dtype=torch.bfloat16,
#     )

#     base_model = LlavaNextForConditionalGeneration.from_pretrained(
#         MODEL_ID,
#         quantization_config=bnb_config,
#         torch_dtype=torch.float16,
#         low_cpu_mem_usage=True,
#         device_map="auto",
#     )

#     base_model = prepare_model_for_kbit_training(
#         base_model, use_gradient_checkpointing=True
#     )

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [7]:
lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()

trainable params: 22,151,168 || all params: 7,588,898,816 || trainable%: 0.2919


## Create dataloader

In [None]:
@contextlib.contextmanager
def temporary_padding_side(tokenizer, side):
    """Temporarily change padding side (left/right) inside a `with` block."""
    original = tokenizer.padding_side
    tokenizer.padding_side = side
    try:
        yield
    finally:
        tokenizer.padding_side = original


def build_prompt_inputs(images, questions):
    """Tokenise the (question + image placeholder) prompt with left‑padding."""
    conversations = [
        [{"role": "user", "content": [{"type": "text", "text": q}, {"type": "image"}]}]
        for q in questions
    ]
    prompts = [processor.apply_chat_template(c, add_generation_prompt=True) for c in conversations]
    encoded = processor(images=images, text=prompts, padding=True, return_tensors="pt")
    return {k: v.to(DEVICE) for k, v in encoded.items()}


def tokenize_answers(texts, max_length):
    """Right‑pad assistant answers and append EOS."""

    # Reason for the right pad tokenization -
    # later I will concatenate prompt tokens and potential answer tokens,
    # to get logits in one go, without writing a loop.
    # Having pad tokens in the middle seems very confusing
    # and can be misleading and couse errors in the future.

    with temporary_padding_side(TOKENIZER, "right"):
        encoded = TOKENIZER(
            texts,
            padding=True,
            truncation=True,
            add_special_tokens=False,
            return_tensors="pt",
        )
    ids, mask = encoded["input_ids"], encoded["attention_mask"]
    eos_column = torch.full((ids.size(0), 1), EOS_ID, dtype=ids.dtype)
    ids = torch.cat([ids, eos_column], dim=1)
    mask = torch.cat([mask, torch.ones_like(eos_column)], dim=1)

    if max_length is not None:
        # Trim if longer than max_length
        ids = ids[:, :max_length]
        mask = mask[:, :max_length]

    return ids.to(DEVICE), mask.to(DEVICE)


def collate_fn(batch):
    images = [item["image"] for item in batch]
    questions = [item["question"] for item in batch]
    chosen_texts = [item["chosen"] for item in batch]
    rejected_texts = [item["rejected"] for item in batch]

    prompt_inputs = build_prompt_inputs(images, questions)
    chosen_ids, chosen_mask = tokenize_answers(chosen_texts, max_length=MAX_ANSWER_TOKENS)
    rejected_ids, rejected_mask = tokenize_answers(rejected_texts, max_length=MAX_ANSWER_TOKENS)

    return prompt_inputs, chosen_ids, chosen_mask, rejected_ids, rejected_mask

In [9]:
raw_dataset = load_dataset(DATASET_NAME, split="train[:5%]")
first_split = raw_dataset.train_test_split(test_size=VAL_RATIO + TEST_RATIO, seed=42)
train_dataset = first_split["train"]
val_test_dataset = first_split["test"]
val_fraction_of_tmp = VAL_RATIO / (VAL_RATIO + TEST_RATIO)
second_split = val_test_dataset.train_test_split(test_size=1 - val_fraction_of_tmp, seed=42)
val_dataset = second_split["train"]
test_dataset = second_split["test"]

train_loader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=VAL_BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

# Get logits

In [None]:
def get_prompt_cache(prompt_inputs):
    """Get the last logits and past key values for the prompt inputs."""
    
    training  = model.training

    model.eval()
    with torch.no_grad():
        out = model(**prompt_inputs, return_dict=True)

    if training:
        model.train()

    return out.logits[:, -1:, :], out.past_key_values

def answer_logits(prompt_inputs, chosen_ids, chosen_mask, rejected_ids, rejected_mask):
    """Get logits for the chosen and rejected answers, aligned with the prompt inputs."""

    last_logits, prompt_kv = get_prompt_cache(prompt_inputs)

    # raw logits when we feed the full answers
    raw_chosen = model(
        input_ids=chosen_ids,
        attention_mask=chosen_mask,
        past_key_values=prompt_kv
    ).logits          # (B,N,V)

    raw_rejected = model(
        input_ids=rejected_ids,
        attention_mask=rejected_mask,
        past_key_values=prompt_kv
    ).logits          # (B,N,V)

    # align: prepend last_prompt_logits and drop the last timestep
    chosen_logits = torch.cat([last_logits, raw_chosen[:, :-1, :]],  dim=1)
    rejected_logits = torch.cat([last_logits, raw_rejected[:, :-1, :]], dim=1)

    return chosen_logits, rejected_logits

In [None]:
def token_logp(logits, ids):
    """Get log probabilities for the given token IDs from the logits."""
    logp = log_softmax(logits, dim=-1)
    return logp.gather(2, ids.unsqueeze(-1)).squeeze(-1)    # (B,N)

def log_prob(logs, mask):
    """Get the average log probability for the given logs and mask."""
    return (logs * mask).sum(dim=-1) / mask.sum(dim=-1)

def log_odds(log_prob):
    """Convert log probability to log odds."""
    return log_prob - torch.log1p(-torch.exp(log_prob))


In [None]:
def loss_orpo(chosen_logits, rejected_logits, chosen_ids, rejected_ids, chosen_mask, rejected_mask, lam):
    """Calculate the ORPO loss for the chosen and rejected logits."""
    chosen_logits = token_logp(chosen_logits, chosen_ids)   # (B,N)
    rejected_logits = token_logp(rejected_logits, rejected_ids)  # (B,N)

    chosen_logp = log_prob(chosen_logits, chosen_mask)  # (B,)
    rejected_logp = log_prob(rejected_logits, rejected_mask)  # (B,)
    
    log_odds_chosen = log_odds(chosen_logp)  # (B,)
    log_odds_rejected = log_odds(rejected_logp)  # (B,)

    L_sft = -chosen_logp.mean()  # supervised fine-tuning loss
    L_or = -torch.log(
        torch.sigmoid(log_odds_chosen - log_odds_rejected)
    ).mean()


    return L_sft + lam * L_or, L_sft, L_or

In [13]:
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.95), weight_decay=0.0)
steps_per_epoch = len(train_loader) // GRAD_ACC_STEPS
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    int(steps_per_epoch * EPOCHS * WARMUP_RATIO),
    steps_per_epoch * EPOCHS,
)

wandb.init(project=WANDB_PROJECT, config={k: v for k, v in globals().items() if k.isupper()})

wandb.watch(model, log="gradients", log_freq=LOG_EVERY_STEPS)

best_val = float("inf")   # lower is better for ORPO
best_step = -1


model.train()
acc_steps = 0
running_loss = 0.0  # To accumulate loss for average

for epoch in range(EPOCHS):
    for global_step, batch in tqdm(enumerate(train_loader, 1), total=len(train_loader), desc=f"Epoch {epoch+1}/{EPOCHS}"):
        prompt_inputs, chosen_ids, chosen_mask, rejected_ids, rejected_mask = batch

        chosen_logits, rejected_logits = answer_logits(prompt_inputs, chosen_ids, chosen_mask, rejected_ids, rejected_mask)

        loss_orpo_val, loss_sft_val, loss_or_val = loss_orpo(
            chosen_logits, 
            rejected_logits, 
            chosen_ids, 
            rejected_ids, 
            chosen_mask, 
            rejected_mask, 
            ORPO_LAMBDA
        )
        
        loss = loss_orpo_val / GRAD_ACC_STEPS 
        loss.backward()

        acc_steps += 1
        running_loss += loss.item() * GRAD_ACC_STEPS # Accumulate original loss before division
        
        if acc_steps == GRAD_ACC_STEPS:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step() 
            scheduler.step() 
            optimizer.zero_grad() 
            acc_steps = 0
            

        if global_step % LOG_EVERY_STEPS == 0:
            avg_loss = running_loss / LOG_EVERY_STEPS
            wandb.log(
                {
                    "train/orpo_loss": avg_loss,
                    "lr": scheduler.get_last_lr()[0],
                    "train/loss_orpo": loss_orpo_val.item(),
                    "train/loss_sft": loss_sft_val.item(),
                    "train/loss_or": loss_or_val.item(),
                },
                step=global_step
            )
            running_loss = 0.0 # Reset running loss


        if global_step % VAL_EVERY_STEPS == 0 or global_step == len(train_loader) - 1:
            model.eval()
            val_orpo_list, val_sft_list, val_or_list = [], [], []

            with torch.no_grad():
                for batch in tqdm(val_loader, desc="Validation"):
                    prompt_inputs, chosen_ids, chosen_mask, rejected_ids, rejected_mask = batch

                    # single prompt pass, same as training
                    chosen_logits, rejected_logits = answer_logits(
                        prompt_inputs,
                        chosen_ids,   chosen_mask,
                        rejected_ids, rejected_mask
                    )

                    loss_orpo_val, loss_sft_val, loss_or_val = loss_orpo(
                        chosen_logits, rejected_logits,
                        chosen_ids, rejected_ids,
                        chosen_mask, rejected_mask,
                        ORPO_LAMBDA
                    )

                    val_orpo_list.append(loss_orpo_val.item())
                    val_sft_list.append(loss_sft_val.item())
                    val_or_list.append(loss_or_val.item())

            mean_val_orpo = sum(val_orpo_list) / len(val_orpo_list)

            # checkpoint if best
            if mean_val_orpo < best_val:
                best_val = mean_val_orpo
                best_step = global_step
                ckpt_dir = f"{OUTPUT_DIR}/step_{best_step}"
                os.makedirs(ckpt_dir, exist_ok=True)
                model.save_pretrained(ckpt_dir)
                wandb.run.summary.update({
                    "best_val_loss": best_val,
                    "best_step": best_step
                })
                print(f"★ New best val_loss {best_val:.4f} at step {best_step} — adapters saved to {ckpt_dir}")

            wandb.log({
                "val/orpo_loss": mean_val_orpo,
                "val/loss_sft":  sum(val_sft_list) / len(val_sft_list),
                "val/loss_or":   sum(val_or_list)  / len(val_or_list),
            }, step=global_step)

            model.train()

[34m[1mwandb[0m: Currently logged in as: [33mgrach0v[0m ([33mcowboy_bebop[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 1/1:   0%|          | 0/3117 [00:00<?, ?it/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

★ New best val_loss 10.5060 at step 200 — adapters saved to ../logs//step_200


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Validation:   0%|          | 0/35 [00:00<?, ?it/s]

★ New best val_loss 10.3085 at step 600 — adapters saved to ../logs//step_600


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Validation:   0%|          | 0/35 [00:00<?, ?it/s]

★ New best val_loss 10.0305 at step 1000 — adapters saved to ../logs//step_1000


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Validation:   0%|          | 0/35 [00:00<?, ?it/s]

★ New best val_loss 9.9580 at step 1400 — adapters saved to ../logs//step_1400


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

★ New best val_loss 9.8866 at step 1600 — adapters saved to ../logs//step_1600


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

★ New best val_loss 9.8009 at step 1800 — adapters saved to ../logs//step_1800


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

★ New best val_loss 9.7561 at step 2000 — adapters saved to ../logs//step_2000


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

★ New best val_loss 9.7541 at step 2200 — adapters saved to ../logs//step_2200


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Validation:   0%|          | 0/35 [00:00<?, ?it/s]

★ New best val_loss 9.6849 at step 2600 — adapters saved to ../logs//step_2600


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Validation:   0%|          | 0/35 [00:00<?, ?it/s]

In [14]:
ckpt_dir = f"{OUTPUT_DIR}/last"
os.makedirs(ckpt_dir, exist_ok=True)
model.save_pretrained(ckpt_dir)

In [15]:
# Load the best saved model
if best_step != -1:
    best_ckpt_dir = f"{OUTPUT_DIR}/step_{best_step}"
    print(f"Loading best model from step {best_step}: {best_ckpt_dir}")
    model.load_adapter(best_ckpt_dir, adapter_name="best")
    model.set_adapter("best")
else:
    print("No best model found, using current model state")

model.eval()
test_orpo_list, test_sft_list, test_or_list = [], [], []

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Testing"):
        prompt_inputs, chosen_ids, chosen_mask, rejected_ids, rejected_mask = batch

        chosen_logits, rejected_logits = answer_logits(
            prompt_inputs,
            chosen_ids, chosen_mask,
            rejected_ids, rejected_mask
        )

        loss_orpo_val, loss_sft_val, loss_or_val = loss_orpo(
            chosen_logits, rejected_logits,
            chosen_ids, rejected_ids,
            chosen_mask, rejected_mask,
            ORPO_LAMBDA
        )

        test_orpo_list.append(loss_orpo_val.item())
        test_sft_list.append(loss_sft_val.item())
        test_or_list.append(loss_or_val.item())

mean_test_orpo = sum(test_orpo_list) / len(test_orpo_list)
mean_test_sft = sum(test_sft_list) / len(test_sft_list)
mean_test_or = sum(test_or_list) / len(test_or_list)

print(f"Test Results:")
print(f"ORPO Loss: {mean_test_orpo:.4f}")
print(f"SFT Loss: {mean_test_sft:.4f}")
print(f"OR Loss: {mean_test_or:.4f}")

wandb.log({
    "test/orpo_loss": mean_test_orpo,
    "test/loss_sft": mean_test_sft,
    "test/loss_or": mean_test_or,
})

# ───────── save adapters & finish ─────────
model.save_pretrained(OUTPUT_DIR)
wandb.finish()

Loading best model from step 2600: ../logs//step_2600


Testing:   0%|          | 0/139 [00:00<?, ?it/s]

Test Results:
ORPO Loss: 10.1229
SFT Loss: 7.0619
OR Loss: 0.6122


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
lr,▂█████▇▇▇▇▇▆▆▆▆▆▆▅▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▁▁▁▁▁▁▁
test/loss_or,▁
test/loss_sft,▁
test/orpo_loss,▁
train/loss_or,▃▃▃▂▂▂▄▂█▁▁▂▁▁▁▂▁▂▁▄▂▁▃▁▁▁▁▁▁▁▂▁▃▁▁▃▁▁▂▃
train/loss_orpo,▅▄▃▃▄▄▃▃▅▅▁▄▆▃▂▂▂█▁▂▁▂▅▃▁▂▁▁▇▄▂▂▁▁▂▃▃▂▆▁
train/loss_sft,▄▄▃▃▂▁▂▂▃▃▄▂▄▁▆█▂▁▂█▂▂▅▂▄▂▄▂▂▂▃▃▇▂▃▂▂▂▃▂
train/orpo_loss,▅▆▇▆▅▆▅▆▃▄▄▄▂▃▅▅▃▄█▅▃▃▅▂▄▂▄▃▃▄▂▂▃▄▁▁▇▂▁▁
val/loss_or,█▆▅▆▅▄▂▃▂▁▁▁▁▁▁▁
val/loss_sft,▅█▅█▁▄▄▂▁▂▂▂▁▁▁▁

0,1
best_step,2600.0
best_val_loss,9.68493
lr,0.0
test/loss_or,0.61222
test/loss_sft,7.06194
test/orpo_loss,10.12295
train/loss_or,0.19214
train/loss_orpo,4.65625
train/loss_sft,3.69727
train/orpo_loss,4.33838
