In [6]:
"""
Build Direct Preference Optimization from scratch
@author: demouo
"""

# Model
from transformers import AutoModelForCausalLM, AutoTokenizer
from unsloth import FastLanguageModel


model_path = "Llama-3.2-1B-Instruct"
max_seq_length = 2056
load_in_4bit = True
# model = AutoModelForCausalLM(model_path, max_seq_length=max_seq_length)
# tokenizer = AutoTokenizer(model_path)
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_path,
    max_seq_length=max_seq_length,
    load_in_4bit=load_in_4bit
)

ref_model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_path,
    max_seq_length=max_seq_length,
    load_in_4bit=load_in_4bit
)


==((====))==  Unsloth 2025.7.4: Fast Llama patching. Transformers: 4.53.2. vLLM: 0.9.2.
   \\   /|    NVIDIA GeForce RTX 3090. Num GPUs = 4. Max memory: 23.691 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.0+cu126. CUDA: 8.6. CUDA Toolkit: 12.6. Triton: 3.3.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.30. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
==((====))==  Unsloth 2025.7.4: Fast Llama patching. Transformers: 4.53.2. vLLM: 0.9.2.
   \\   /|    NVIDIA GeForce RTX 3090. Num GPUs = 4. Max memory: 23.691 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.0+cu126. CUDA: 8.6. CUDA Toolkit: 12.6. Triton: 3.3.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.30. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


In [None]:
# SFT model
model = FastLanguageModel.get_peft_model(
    model,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],

    r = 16,           # Larger = higher accuracy, but might overfit
    lora_alpha = 16,  # Recommended alpha == r at least
    lora_dropout = 0.1,
    bias = "none",
    random_state = 3407, # Do not modify the random_state for reproducibility
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

In [25]:
import torch.nn.functional as F
import torch

model.train()
ref_model.eval()
for param in ref_model.parameters():
    param.requires_grad = False

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-2)

# One train example
train_example = {
    "question": "I'm fat, how about you?",
    "chosen": "Oh sorry to hear, what's the point now?",
    "rejected": "Aha, so fat so funny your guy?",
}
beta = 0.1
prompt = train_example["question"]
chosen_response = train_example["chosen"]
rejected_response = train_example["rejected"]

chosen_input_text = prompt + tokenizer.eos_token + chosen_response
chosen_tokenized_text = tokenizer(chosen_input_text, return_tensors="pt").to("cuda:0")
chosen_input_ids = chosen_tokenized_text.input_ids
print(f"chosen_input_ids shape: {chosen_input_ids.shape}")
chosen_attention_mask = chosen_tokenized_text.attention_mask

rejected_input_text = prompt + tokenizer.eos_token + rejected_response
rejected_tokenized_text = tokenizer(rejected_input_text, return_tensors="pt").to(
    "cuda:0"
)
rejected_input_ids = rejected_tokenized_text.input_ids
print(f"rejected_input_ids shape: {rejected_input_ids.shape}")
rejected_attention_mask = rejected_tokenized_text.attention_mask

# Forward
# (Policy) model
# Log probabilities of chosen texts
policy_chosen_outputs = model(
    input_ids=chosen_input_ids, attention_mask=chosen_attention_mask
)
policy_chosen_log_probs = F.log_softmax(policy_chosen_outputs.logits, dim=-1)
print(f"policy_chosen_log_probs shape: {policy_chosen_log_probs.shape}")
print(f"policy_chosen_log_probs:\n {policy_chosen_log_probs}")

# Log probabilities of rejected texts
policy_rejected_outputs = model(
    input_ids=rejected_input_ids, attention_mask=rejected_attention_mask
)
policy_rejected_log_probs = F.log_softmax(policy_rejected_outputs.logits, dim=-1)
print(f"policy_rejected_log_probs shape: {policy_rejected_log_probs.shape}")
print(f"policy_rejected_log_probs:\n {policy_rejected_log_probs}")

# Reference model
with torch.no_grad():
    # Log probabilities of chosen texts
    refer_chosen_outputs = ref_model(
        input_ids=chosen_input_ids, attention_mask=chosen_attention_mask
    )
    refer_chosen_log_probs = F.log_softmax(refer_chosen_outputs.logits, dim=-1)
    print(f"refer_chosen_log_probs shape: {refer_chosen_log_probs.shape}")
    print(f"refer_chosen_log_probs:\n {refer_chosen_log_probs}")

    # Log probabilities of rejected texts
    refer_rejected_outputs = ref_model(
        input_ids=rejected_input_ids, attention_mask=rejected_attention_mask
    )
    refer_rejected_log_probs = F.log_softmax(refer_rejected_outputs.logits, dim=-1)
    print(f"refer_rejected_log_probs shape: {refer_rejected_log_probs.shape}")
    print(f"refer_rejected_log_probs:\n {refer_rejected_log_probs}")

# Extract response part only
prompt_length = tokenizer(
    prompt + tokenizer.eos_token, return_tensors="pt"
).input_ids.shape[1]
print(f"prompt_length: {prompt_length}")
policy_chosen_log_probs_response = policy_chosen_log_probs[:, prompt_length - 1 : -1, :]
print(f"policy_chosen_log_probs_response shape: {policy_chosen_log_probs_response.shape}")
refer_chosen_log_probs_response = refer_chosen_log_probs[:, prompt_length - 1 : -1, :]
print(f"refer_chosen_log_probs_response shape: {refer_chosen_log_probs_response.shape}")
policy_rejected_log_probs_response = policy_rejected_log_probs[
    :, prompt_length - 1 : -1, :
]
print(f"policy_rejected_log_probs_response shape: {policy_rejected_log_probs_response.shape}")
refer_rejected_log_probs_response = refer_rejected_log_probs[
    :, prompt_length - 1 : -1, :
]
print(f"refer_rejected_log_probs_response shape: {refer_rejected_log_probs_response.shape}")

# Targets (Labels)
labels_chosen = chosen_input_ids[:, prompt_length:]
labels_rejected = rejected_input_ids[:, prompt_length:]
print(f"labels_chosen shape: {labels_chosen.shape}")
print(f"labels_rejected shape: {labels_rejected.shape}")

def get_response_log_probs(log_probs, labels):
    """
    Gather targets token's probability from vocab_size dim

    Args:
        log_probs (tensor[batch_size, seq_length, vocab_size]): The generated logit
        labels (tensor[batch_size, seq_length]): The true labels (indies)

    Returns:
        'tensor([batch_size])': Loss of each sample
    """
    # [batch_size, seq_length] -> [batch_size, seq_length, 1]
    print(f"labels shape: {labels.shape}")
    labels = labels.unsqueeze(-1)
    print(f"After unsqueeze, labels shape: {labels.shape}")
    per_token_logps = torch.gather(log_probs, dim=2, index=labels)
    print(f"per_token_logps shape: {per_token_logps.shape}")

    # [batch_size, seq_length, 1] -> [batch_size, sequence] -> [batch_size]
    response_log_probs = per_token_logps.squeeze(-1).sum(dim=1)
    print(f"response_log_probs shape: {response_log_probs.shape}")
    return response_log_probs


# Loss
pi_chosen_log_prob = get_response_log_probs(
    policy_chosen_log_probs_response, labels_chosen
)
print(f"pi_chosen_log_prob shape: {pi_chosen_log_prob.shape}")
print(f"pi_chosen_log_prob: {pi_chosen_log_prob}")
pi_rejected_log_probs = get_response_log_probs(
    policy_rejected_log_probs_response, labels_rejected
)
print(f"pi_rejected_log_probs shape: {pi_rejected_log_probs.shape}")
print(f"pi_rejected_log_probs: {pi_rejected_log_probs}")

ref_chosen_log_prob = get_response_log_probs(
    refer_chosen_log_probs_response, labels_chosen
)
print(f"ref_chosen_log_prob shape: {ref_chosen_log_prob.shape}")
print(f"ref_chosen_log_prob: {ref_chosen_log_prob}")

ref_rejected_log_prob = get_response_log_probs(
    refer_rejected_log_probs_response, labels_rejected
)
print(f"ref_rejected_log_prob shape: {ref_rejected_log_prob.shape}")
print(f"ref_rejected_log_prob: {ref_rejected_log_prob}")

# dpo loss
policy_reward_chosen = beta * (pi_chosen_log_prob - ref_chosen_log_prob)
policy_reward_rejected = beta * (pi_rejected_log_probs - ref_rejected_log_prob)
print(f"policy_reward_chosen shape: {policy_reward_chosen.shape}")
print(f"policy_reward_rejected shape: {policy_reward_rejected.shape}")

loss = -F.logsigmoid(policy_reward_chosen - policy_reward_rejected).mean()
print(f"loss shape: {loss.shape}")
print(f"loss: {loss}")

# Backward and steps
optimizer.zero_grad()
loss.backward()
optimizer.step()

for param in model.parameters():
    if param.grad is not None:
        print(f"param shape: {param.shape}")
        print(f"param: {param}")
        print(f"param grad: {param.grad}")
        assert False

chosen_input_ids shape: torch.Size([1, 21])
rejected_input_ids shape: torch.Size([1, 20])
policy_chosen_log_probs shape: torch.Size([1, 21, 128256])
policy_chosen_log_probs:
 tensor([[[-15.7500, -14.7500, -12.5625,  ..., -14.4375, -14.4375, -14.4375],
         [-26.3750, -20.2500, -22.1250,  ..., -20.5000, -20.5000, -20.5000],
         [-25.0000, -18.7500, -20.6250,  ..., -19.1250, -19.1250, -19.1250],
         ...,
         [-27.2500, -21.1250, -22.6250,  ..., -21.0000, -21.0000, -21.0000],
         [-29.2500, -23.5000, -24.6250,  ..., -23.0000, -23.0000, -23.0000],
         [-25.6250, -19.3750, -21.1250,  ..., -19.6250, -19.6250, -19.6250]]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<LogSoftmaxBackward0>)
policy_rejected_log_probs shape: torch.Size([1, 20, 128256])
policy_rejected_log_probs:
 tensor([[[-15.8125, -14.9375, -12.6250,  ..., -14.4375, -14.4375, -14.4375],
         [-26.1250, -20.1250, -21.7500,  ..., -20.2500, -20.2500, -20.2500],
         [-26.0000, -19.750

AssertionError: 