In [8]:
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
model_name = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Policy model (trainable)
policy_model = AutoModelForCausalLM.from_pretrained(model_name)
policy_model.train()

# Reference model (fixed, no grad)
reference_model = AutoModelForCausalLM.from_pretrained(model_name)
reference_model.eval()

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb):

In [3]:
sample = {
    "prompt": "The weather today is",
    "chosen": " sunny and warm.",
    "rejected": " not a banana."
}

prompt = sample["prompt"]

# Tokenize
chosen = tokenizer(
    prompt + sample["chosen"], return_tensors="pt", padding=True
)
rejected = tokenizer(
    prompt + sample["rejected"], return_tensors="pt", padding=True
)

In [4]:
# Forward pass - policy model
policy_chosen_outputs = policy_model(**chosen)
policy_rejected_outputs = policy_model(**rejected)

# Forward pass - reference model (no_grad)
with torch.no_grad():
    ref_chosen_outputs = reference_model(**chosen)
    ref_rejected_outputs = reference_model(**rejected)

In [5]:
# Get logits
def compute_log_prob(outputs, inputs):
    logits = outputs.logits[:, :-1, :]
    labels = inputs["input_ids"][:, 1:]
    log_probs = torch.gather(
        logits.log_softmax(dim=-1), dim=2, index=labels.unsqueeze(-1)
    ).squeeze(-1).sum(dim=1)
    return log_probs

policy_chosen_logp = compute_log_prob(policy_chosen_outputs, chosen)
policy_rejected_logp = compute_log_prob(policy_rejected_outputs, rejected)
ref_chosen_logp = compute_log_prob(ref_chosen_outputs, chosen)
ref_rejected_logp = compute_log_prob(ref_rejected_outputs, rejected)

In [6]:
# DPO loss function
beta = 0.1

diff = (policy_chosen_logp - ref_chosen_logp) - (policy_rejected_logp - ref_rejected_logp)
dpo_loss = -torch.nn.functional.logsigmoid(beta * diff).mean()
dpo_loss

tensor(0.6931, grad_fn=<NegBackward0>)