## Imports

In [4]:
import os

from tqdm.auto import tqdm

import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import (
    LlavaNextForConditionalGeneration,
    LlavaNextProcessor,
    get_cosine_schedule_with_warmup,
)
from peft import LoraConfig, get_peft_model

import wandb

## Config

In [5]:
from config import (
    MODEL_ID,
    DATASET_NAME,
    WANDB_PROJECT,
    OUTPUT_DIR,
    # USE_QLORA,
    TRAIN_BATCH_SIZE,
    VAL_BATCH_SIZE,
    TEST_BATCH_SIZE,
    GRAD_ACC_STEPS,
    EPOCHS,
    LEARNING_RATE,
    WARMUP_RATIO,
    VAL_RATIO,
    TEST_RATIO,
    LORA_R,
    LORA_ALPHA,
    LORA_DROPOUT,
    ORPO_LAMBDA,
    LOG_EVERY_STEPS,
    VAL_EVERY_STEPS,
    DEVICE,
    MAX_ANSWER_TOKENS,
)

## Load the model

In [6]:
# Processor for MODEL_ID. Its tokenizer and EOS token id are exposed as
# module-level names because the collate function is configured with them.
processor = LlavaNextProcessor.from_pretrained(MODEL_ID, use_fast=True)
TOKENIZER = processor.tokenizer
EOS_ID = TOKENIZER.eos_token_id

# Base model (before any adapter is attached), loaded with fp16 weights.
base_model = LlavaNextForConditionalGeneration.from_pretrained(
    MODEL_ID,
    device_map="auto",
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)

# --- Disabled QLoRA path ----------------------------------------------------
# NOTE: re-enabling this requires importing `BitsAndBytesConfig` and
# `prepare_model_for_kbit_training`, and un-commenting `USE_QLORA` in the
# config import; none of them are currently in scope.
#
# if USE_QLORA:
#     # 4-bit base + gradient-ckpt
#     bnb_config = BitsAndBytesConfig(
#         load_in_4bit=True,
#         bnb_4bit_quant_type="nf4",
#         bnb_4bit_compute_dtype=torch.bfloat16,
#     )

#     base_model = LlavaNextForConditionalGeneration.from_pretrained(
#         MODEL_ID,
#         quantization_config=bnb_config,
#         torch_dtype=torch.float16,
#         low_cpu_mem_usage=True,
#         device_map="auto",
#     )

#     base_model = prepare_model_for_kbit_training(
#         base_model, use_gradient_checkpointing=True
#     )

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [7]:
# LoRA hyper-parameters come from the config module; no bias terms are trained.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    # Module names that receive an adapter.
    # Presumably the attention and MLP projections — TODO confirm against the
    # model's module list.
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)

# Wrap the base model with the adapters and report how many weights will train.
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()

trainable params: 22,151,168 || all params: 7,588,898,816 || trainable%: 0.2919


## Create dataloader

In [8]:
from functools import partial
from dataloader_helper import collate_fn as _raw_collate_fn

# Pre-bind everything except the batch itself, so the DataLoaders can be given
# a one-argument callable. The raw helper stays reachable under its alias.
collate_fn = partial(
    _raw_collate_fn,
    processor=processor,
    TOKENIZER=TOKENIZER,
    EOS_ID=EOS_ID,
    MAX_ANSWER_TOKENS=MAX_ANSWER_TOKENS,
    DEVICE=DEVICE,
)

In [9]:
# Only the first 5% of the "train" split is used.
raw_dataset = load_dataset(DATASET_NAME, split="train[:5%]")

# Stage 1: carve off the combined validation + test share; the rest is train.
train_vs_holdout = raw_dataset.train_test_split(
    test_size=VAL_RATIO + TEST_RATIO,
    seed=42,
)
train_dataset = train_vs_holdout["train"]
holdout_dataset = train_vs_holdout["test"]

# Stage 2: divide the held-out part so that validation keeps
# VAL_RATIO / (VAL_RATIO + TEST_RATIO) of it and test receives the remainder.
val_share_of_holdout = VAL_RATIO / (VAL_RATIO + TEST_RATIO)
val_vs_test = holdout_dataset.train_test_split(
    test_size=1 - val_share_of_holdout,
    seed=42,
)
val_dataset = val_vs_test["train"]
test_dataset = val_vs_test["test"]

# Only the training loader shuffles; evaluation order stays fixed.
train_loader = DataLoader(
    train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=VAL_BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=TEST_BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
)

In [10]:
# Persist the exact train/val/test splits under OUTPUT_DIR so later runs can
# reload the same partition. `os.path.join` avoids a doubled separator when
# OUTPUT_DIR already ends with "/".
for split_name, split_dataset in (
    ("train_dataset", train_dataset),
    ("val_dataset", val_dataset),
    ("test_dataset", test_dataset),
):
    split_dataset.save_to_disk(os.path.join(OUTPUT_DIR, split_name))

print(f"Datasets saved to {OUTPUT_DIR}")
print(f"Train dataset: {len(train_dataset)} samples")
print(f"Validation dataset: {len(val_dataset)} samples")
print(f"Test dataset: {len(test_dataset)} samples")

Saving the dataset (0/2 shards):   0%|          | 0/3117 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/208 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/832 [00:00<?, ? examples/s]

Datasets saved to ../logs/
Train dataset: 3117 samples
Validation dataset: 208 samples
Test dataset: 832 samples


## Get logits and train with ORPO

In [11]:
from orpo_helper import answer_logits, loss_orpo

In [12]:
# ORPO fine-tuning loop.
#
# * Gradients are accumulated over GRAD_ACC_STEPS micro-batches before each
#   optimizer / scheduler step.
# * `global_step` counts micro-batches across *all* epochs, so the W&B step
#   axis and the checkpoint names never go backwards or collide when
#   EPOCHS > 1.
# * Validation runs every VAL_EVERY_STEPS micro-batches and again on the last
#   batch of every epoch; the adapter with the lowest mean validation ORPO
#   loss is saved.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    betas=(0.9, 0.95),
    weight_decay=0.0,
)

batches_per_epoch = len(train_loader)
# The scheduler advances once per optimizer step, not once per micro-batch.
steps_per_epoch = batches_per_epoch // GRAD_ACC_STEPS
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    int(steps_per_epoch * EPOCHS * WARMUP_RATIO),
    steps_per_epoch * EPOCHS,
)

# Every upper-case global (config constants and friends) is recorded as the
# run configuration.
wandb.init(project=WANDB_PROJECT, config={k: v for k, v in globals().items() if k.isupper()})

wandb.watch(model, log="gradients", log_freq=LOG_EVERY_STEPS)

best_val = float("inf")   # lower is better for ORPO
best_step = -1            # -1 means "no checkpoint saved yet"

model.train()
acc_steps = 0        # micro-batches accumulated since the last optimizer step
running_loss = 0.0   # sum of undivided ORPO losses since the last log
global_step = 0      # micro-batches processed so far, across epochs

for epoch in range(EPOCHS):
    epoch_batches = tqdm(
        enumerate(train_loader, 1),
        total=batches_per_epoch,
        desc=f"Epoch {epoch+1}/{EPOCHS}",
    )
    for batch_idx, batch in epoch_batches:
        global_step += 1
        prompt_inputs, chosen_ids, chosen_mask, rejected_ids, rejected_mask = batch

        chosen_logits, rejected_logits = answer_logits(
            model,
            prompt_inputs,
            chosen_ids, chosen_mask,
            rejected_ids, rejected_mask,
        )

        loss_orpo_val, loss_sft_val, loss_or_val = loss_orpo(
            chosen_logits,
            rejected_logits,
            chosen_ids,
            rejected_ids,
            chosen_mask,
            rejected_mask,
            ORPO_LAMBDA,
        )

        # Scale so the accumulated gradient is the mean over the micro-batches.
        loss = loss_orpo_val / GRAD_ACC_STEPS
        loss.backward()

        acc_steps += 1
        running_loss += loss_orpo_val.item()  # track the undivided loss

        if acc_steps == GRAD_ACC_STEPS:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            acc_steps = 0

        if global_step % LOG_EVERY_STEPS == 0:
            avg_loss = running_loss / LOG_EVERY_STEPS
            wandb.log(
                {
                    "train/orpo_loss": avg_loss,
                    "lr": scheduler.get_last_lr()[0],
                    # The three entries below are the values of the most
                    # recent micro-batch only, not an average.
                    "train/loss_orpo": loss_orpo_val.item(),
                    "train/loss_sft": loss_sft_val.item(),
                    "train/loss_or": loss_or_val.item(),
                },
                step=global_step,
            )
            running_loss = 0.0  # start a fresh logging window

        # `enumerate(..., 1)` makes the final batch index equal to
        # `batches_per_epoch`, so compare against that (not `- 1`).
        is_last_batch_of_epoch = batch_idx == batches_per_epoch
        if global_step % VAL_EVERY_STEPS == 0 or is_last_batch_of_epoch:
            model.eval()
            val_orpo_list, val_sft_list, val_or_list = [], [], []

            with torch.no_grad():
                # Separate names keep the training batch and its losses intact.
                for val_batch in tqdm(val_loader, desc="Validation"):
                    (
                        val_prompt_inputs,
                        val_chosen_ids,
                        val_chosen_mask,
                        val_rejected_ids,
                        val_rejected_mask,
                    ) = val_batch

                    # single prompt pass, same as training
                    val_chosen_logits, val_rejected_logits = answer_logits(
                        model,
                        val_prompt_inputs,
                        val_chosen_ids, val_chosen_mask,
                        val_rejected_ids, val_rejected_mask,
                    )

                    val_loss_orpo, val_loss_sft, val_loss_or = loss_orpo(
                        val_chosen_logits, val_rejected_logits,
                        val_chosen_ids, val_rejected_ids,
                        val_chosen_mask, val_rejected_mask,
                        ORPO_LAMBDA,
                    )

                    val_orpo_list.append(val_loss_orpo.item())
                    val_sft_list.append(val_loss_sft.item())
                    val_or_list.append(val_loss_or.item())

            mean_val_orpo = sum(val_orpo_list) / len(val_orpo_list)

            # checkpoint if best
            if mean_val_orpo < best_val:
                best_val = mean_val_orpo
                best_step = global_step
                ckpt_dir = f"{OUTPUT_DIR}/step_{wandb.run.name}_{best_step}"
                os.makedirs(ckpt_dir, exist_ok=True)
                model.save_pretrained(ckpt_dir)
                wandb.run.summary.update({
                    "best_val_loss": best_val,
                    "best_step": best_step,
                })
                print(f"★ New best val_loss {best_val:.4f} at step {best_step} — adapters saved to {ckpt_dir}")

            wandb.log({
                "val/orpo_loss": mean_val_orpo,
                "val/loss_sft":  sum(val_sft_list) / len(val_sft_list),
                "val/loss_or":   sum(val_or_list)  / len(val_or_list),
            }, step=global_step)

            model.train()

[34m[1mwandb[0m: Currently logged in as: [33mgrach0v[0m ([33mcowboy_bebop[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 1/1:   0%|          | 0/3117 [00:00<?, ?it/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

★ New best val_loss 13.6699 at step 200 — adapters saved to ../logs//step_200


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

★ New best val_loss 13.6022 at step 400 — adapters saved to ../logs//step_400


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Validation:   0%|          | 0/35 [00:00<?, ?it/s]

★ New best val_loss 13.2326 at step 1400 — adapters saved to ../logs//step_1400


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Validation:   0%|          | 0/35 [00:00<?, ?it/s]

In [13]:
# Save the adapter in its current state to a fixed "last" directory,
# alongside the best-validation checkpoints written by the training loop.
# Note: this rebinds the global `ckpt_dir` that the training loop also uses.
ckpt_dir = f"{OUTPUT_DIR}/last"
os.makedirs(ckpt_dir, exist_ok=True)
model.save_pretrained(ckpt_dir)

In [None]:
# Load the best checkpoint found during training (if any) as adapter "best",
# then measure the ORPO / SFT / odds-ratio losses on the test split.
if best_step != -1:
    # The training loop names checkpoints "step_<run name>_<step>"; the plain
    # "step_<step>" layout is accepted too. Use whichever exists on disk.
    candidate_dirs = [f"{OUTPUT_DIR}/step_{best_step}"]
    if wandb.run is not None:
        candidate_dirs.insert(0, f"{OUTPUT_DIR}/step_{wandb.run.name}_{best_step}")

    best_ckpt_dir = next((d for d in candidate_dirs if os.path.isdir(d)), None)
    if best_ckpt_dir is None:
        raise FileNotFoundError(
            f"No checkpoint directory for step {best_step}; looked in: {candidate_dirs}"
        )

    print(f"Loading best model from step {best_step}: {best_ckpt_dir}")
    model.load_adapter(best_ckpt_dir, adapter_name="best")
    model.set_adapter("best")
else:
    print("No best model found, using current model state")

# Test with adapter
print("Testing with adapter...")
model.eval()
test_orpo_list, test_sft_list, test_or_list = [], [], []

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Testing with adapter"):
        prompt_inputs, chosen_ids, chosen_mask, rejected_ids, rejected_mask = batch

        chosen_logits, rejected_logits = answer_logits(
            model, prompt_inputs,
            chosen_ids, chosen_mask,
            rejected_ids, rejected_mask
        )

        loss_orpo_val, loss_sft_val, loss_or_val = loss_orpo(
            chosen_logits, rejected_logits,
            chosen_ids, rejected_ids,
            chosen_mask, rejected_mask,
            ORPO_LAMBDA
        )

        test_orpo_list.append(loss_orpo_val.item())
        test_sft_list.append(loss_sft_val.item())
        test_or_list.append(loss_or_val.item())

# Per-batch means; the next cell compares these with the base-model numbers.
mean_test_orpo = sum(test_orpo_list) / len(test_orpo_list)
mean_test_sft = sum(test_sft_list) / len(test_sft_list)
mean_test_or = sum(test_or_list) / len(test_or_list)

print(f"Test Results (with adapter):")
print(f"ORPO Loss: {mean_test_orpo:.4f}")
print(f"SFT Loss: {mean_test_sft:.4f}")
print(f"OR Loss: {mean_test_or:.4f}")

Loading best model from step 1400: ../logs//step_1400
Testing with adapter...


Testing with adapter:   0%|          | 0/139 [00:00<?, ?it/s]

Test Results (with adapter):
ORPO Loss: 13.5965
SFT Loss: 7.4400
OR Loss: 0.6156

Testing without adapter (base model)...


ValueError: No adapter loaded. Please load an adapter first.

In [None]:
# Test without adapter (base model): same evaluation as the adapter test, run
# inside `model.disable_adapter()`, then compare with the `mean_test_*` values
# from the adapter test cell and log everything to W&B.
print("\nTesting without adapter (base model)...")

# Do not rely on an earlier cell having switched the model to eval mode; the
# training loop leaves it in train mode.
model.eval()

base_test_orpo_list, base_test_sft_list, base_test_or_list = [], [], []

with torch.no_grad(), model.disable_adapter():
    for batch in tqdm(test_loader, desc="Testing without adapter"):
        prompt_inputs, chosen_ids, chosen_mask, rejected_ids, rejected_mask = batch

        chosen_logits, rejected_logits = answer_logits(
            model, prompt_inputs,
            chosen_ids, chosen_mask,
            rejected_ids, rejected_mask
        )

        loss_orpo_val, loss_sft_val, loss_or_val = loss_orpo(
            chosen_logits, rejected_logits,
            chosen_ids, rejected_ids,
            chosen_mask, rejected_mask,
            ORPO_LAMBDA
        )

        base_test_orpo_list.append(loss_orpo_val.item())
        base_test_sft_list.append(loss_sft_val.item())
        base_test_or_list.append(loss_or_val.item())

base_mean_test_orpo = sum(base_test_orpo_list) / len(base_test_orpo_list)
base_mean_test_sft = sum(base_test_sft_list) / len(base_test_sft_list)
base_mean_test_or = sum(base_test_or_list) / len(base_test_or_list)

print(f"Test Results (without adapter):")
print(f"ORPO Loss: {base_mean_test_orpo:.4f}")
print(f"SFT Loss: {base_mean_test_sft:.4f}")
print(f"OR Loss: {base_mean_test_or:.4f}")

# Improvement = base loss minus adapter loss, so positive means the adapter
# lowered the loss. Computed once and reused for printing and logging.
orpo_improvement = base_mean_test_orpo - mean_test_orpo
sft_improvement = base_mean_test_sft - mean_test_sft
or_improvement = base_mean_test_or - mean_test_or

print(f"\nImprovement with adapter:")
print(f"ORPO Loss improvement: {orpo_improvement:.4f}")
print(f"SFT Loss improvement: {sft_improvement:.4f}")
print(f"OR Loss improvement: {or_improvement:.4f}")

# Log results to wandb
wandb.log({
    "test/orpo_loss_with_adapter": mean_test_orpo,
    "test/loss_sft_with_adapter": mean_test_sft,
    "test/loss_or_with_adapter": mean_test_or,
    "test/orpo_loss_without_adapter": base_mean_test_orpo,
    "test/loss_sft_without_adapter": base_mean_test_sft,
    "test/loss_or_without_adapter": base_mean_test_or,
    "test/orpo_improvement": orpo_improvement,
    "test/sft_improvement": sft_improvement,
    "test/or_improvement": or_improvement,
})


wandb.finish()


Testing without adapter (base model)...


Testing without adapter:   0%|          | 0/139 [00:00<?, ?it/s]

Test Results (without adapter):
ORPO Loss: 14.3907
SFT Loss: 7.3019
OR Loss: 0.7089

Improvement with adapter:
ORPO Loss improvement: 0.7942
SFT Loss improvement: -0.1381
OR Loss improvement: 0.0932


ValueError: No adapter loaded. Please load an adapter first.

In [22]:
# Load the adapter saved at the end of training (the "last" checkpoint, not
# the best-validation one) and make it the active adapter.
last_ckpt_dir = f"{OUTPUT_DIR}/last"
model.load_adapter(last_ckpt_dir, adapter_name="last")
model.set_adapter("last")

# Evaluate on the test split with that adapter active.
# Note: this overwrites the `mean_test_*` values from the earlier adapter test.
print("Testing with adapter...")
model.eval()
test_orpo_list, test_sft_list, test_or_list = [], [], []
loss_buckets = (test_orpo_list, test_sft_list, test_or_list)

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Testing with adapter"):
        prompt_inputs, chosen_ids, chosen_mask, rejected_ids, rejected_mask = batch

        chosen_logits, rejected_logits = answer_logits(
            model,
            prompt_inputs,
            chosen_ids,
            chosen_mask,
            rejected_ids,
            rejected_mask,
        )

        loss_orpo_val, loss_sft_val, loss_or_val = loss_orpo(
            chosen_logits,
            rejected_logits,
            chosen_ids,
            rejected_ids,
            chosen_mask,
            rejected_mask,
            ORPO_LAMBDA,
        )

        # Store plain floats: (ORPO, SFT, OR), in the same order as the buckets.
        for bucket, loss_value in zip(loss_buckets, (loss_orpo_val, loss_sft_val, loss_or_val)):
            bucket.append(loss_value.item())

# Mean over test batches for each loss component.
mean_test_orpo = sum(test_orpo_list) / len(test_orpo_list)
mean_test_sft = sum(test_sft_list) / len(test_sft_list)
mean_test_or = sum(test_or_list) / len(test_or_list)

print("Test Results (with adapter):")
print(f"ORPO Loss: {mean_test_orpo:.4f}")
print(f"SFT Loss: {mean_test_sft:.4f}")
print(f"OR Loss: {mean_test_or:.4f}")


Testing with adapter...


Testing with adapter:   0%|          | 0/139 [00:00<?, ?it/s]

Test Results (with adapter):
ORPO Loss: 13.1854
SFT Loss: 7.2381
OR Loss: 0.5947
