In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification
)

In [11]:
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
tokenizer.pad_token = tokenizer.eos_token

samples = [
    {
        "prompt": "What is the capital of France?",
        "chosen": "Paris is the capital of France.",
        "rejected": "France is a city in Paris."
    },
    {
        "prompt": "What is 2 + 2?",
        "chosen": "The answer is 4.",
        "rejected": "2 plus 2 equals 22."
    }
]

In [12]:
class RewardModel(nn.Module):
    def __init__(self, model_name="meta-llama/Llama-3.2-1B"):
        super().__init__()
        self.transformer = AutoModel.from_pretrained(model_name)
        self.value_head = nn.Linear(self.transformer.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.transformer(
            input_ids=input_ids, attention_mask=attention_mask
        )
        last_hidden = outputs.last_hidden_state
        value = self.value_head(last_hidden[:, -1, :]) # B x S x E
        return value.squeeze(-1)


In [13]:
model = RewardModel()

In [14]:
def encode_batch(samples):
    chosen_texts = [s["prompt"] + " " + s["chosen"] for s in samples]
    rejected_texts = [s["prompt"] + " " + s["rejected"] for s in samples]

    chosen = tokenizer(
        chosen_texts, padding=True, truncation=True, return_tensors="pt"
    )
    rejected = tokenizer(
        rejected_texts, padding=True, truncation=True, return_tensors="pt"
    )
    return chosen, rejected

In [15]:
chosen, rejected = encode_batch(samples)
chosen = {k: v for k, v in chosen.items()}
rejected = {k: v for k, v in rejected.items()}

In [16]:
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
loss_fn = nn.MarginRankingLoss(margin=1.0)

for epoch in range(1):
    model.train()
    r_chosen = model(**chosen)
    r_rejected = model(**rejected)
    target = torch.ones_like(r_chosen)

    loss = loss_fn(r_chosen, r_rejected, target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch} | Loss: {loss.item():.4f}")

Epoch 0 | Loss: 2.2268


In [17]:
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Policy model
policy = AutoModelForCausalLM.from_pretrained(model_name)
policy_ref = AutoModelForCausalLM.from_pretrained(model_name)
policy.train()
policy_ref.eval()

# Reward model
reward_model = AutoModelForSequenceClassification.from_pretrained(
    model_name, num_labels=1
)
reward_model.eval()

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


GPT2ForSequenceClassification(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (score): Linear(in_features=768, out_features=1, bias=False)
)

In [None]:
prompt = "I am studying"
inputs = tokenizer(
    prompt, return_tensors="pt", padding=True
)
input_ids = inputs["input_ids"]

with torch.no_grad():
    gen_ids = policy.generate(
        input_ids=input_ids,
        max_new_tokens=20,
        do_sample=True,
        top_k=50,
        temperature=1.0,
        pad_token_id=tokenizer.pad_token_id
    )
response_ids = gen_ids[:, input_ids.shape[-1]:]
query_response = torch.cat([input_ids, response_ids], dim=1)

In [None]:
# log-probs
def get_log_prob_sum(model, input_ids):
    labels = input_ids.clone()
    with torch.no_grad():
        outputs = model(input_ids=input_ids, labels=labels)
        loss = outputs.loss  # average negative log-likelihood
    return -loss  # return log-likelihood

logprob_policy = get_log_prob_sum(policy, query_response)
logprob_ref = get_log_prob_sum(policy_ref, query_response)

In [None]:
logprob_policy

In [None]:
logprob_ref

In [None]:
# reward score
with torch.no_grad():
    reward_inputs = tokenizer(
        tokenizer.decode(query_response[0], skip_special_tokens=True),
        return_tensors="pt", truncation=True, padding=True
    )
    reward = reward_model(**reward_inputs).logits.squeeze().detach()

In [None]:
reward

In [None]:
advantage

In [None]:
ppo_clip_loss

In [None]:
# advantage, PPO loss, KL loss
baseline = reward.detach()
advantage = reward - baseline
log_ratio = logprob_policy - logprob_ref
ratio = torch.exp(log_ratio)

# PPO-clip loss
clip_eps = 0.2
loss1 = ratio * advantage
loss2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantage
ppo_clip_loss = -torch.min(loss1, loss2)

# KL loss (optional penalty)
kl_loss = torch.mean(log_ratio**2)

# Loss
kl_coef = 0.01  # KL-Pen
ppo_loss = ppo_clip_loss + kl_coef * kl_loss

In [None]:
kl_coef

In [None]:
ppo_loss

In [None]:
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
# Load GPT-2
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Policy model (trainable)
policy_model = AutoModelForCausalLM.from_pretrained(model_name)
policy_model.train()

# Reference model (fixed, no grad)
reference_model = AutoModelForCausalLM.from_pretrained(model_name)
reference_model.eval()

In [None]:
sample = {
    "prompt": "The weather today is",
    "chosen": " sunny and warm.",
    "rejected": " not a banana."
}

prompt = sample["prompt"]

# Tokenize
chosen = tokenizer(
    prompt + sample["chosen"], return_tensors="pt", padding=True
)
rejected = tokenizer(
    prompt + sample["rejected"], return_tensors="pt", padding=True
)

In [None]:
# Forward pass - policy model
policy_chosen_outputs = policy_model(**chosen)
policy_rejected_outputs = policy_model(**rejected)

# Forward pass - reference model (no_grad)
with torch.no_grad():
    ref_chosen_outputs = reference_model(**chosen)
    ref_rejected_outputs = reference_model(**rejected)

In [None]:
policy_chosen_outputs.keys()

In [None]:
# Get logits
def compute_log_prob(outputs, inputs):
    logits = outputs.logits[:, :-1, :]
    labels = inputs["input_ids"][:, 1:]
    log_probs = torch.gather(
        logits.log_softmax(dim=-1), dim=2, index=labels.unsqueeze(-1)
    ).squeeze(-1).sum(dim=1)
    return log_probs

policy_chosen_logp = compute_log_prob(policy_chosen_outputs, chosen)
policy_rejected_logp = compute_log_prob(policy_rejected_outputs, rejected)
ref_chosen_logp = compute_log_prob(ref_chosen_outputs, chosen)
ref_rejected_logp = compute_log_prob(ref_rejected_outputs, rejected)

In [None]:
policy_chosen_logp

In [None]:
policy_rejected_logp

In [None]:
ref_chosen_logp

In [None]:
ref_rejected_logp

In [None]:
# DPO loss function
beta = 0.1

diff = (policy_chosen_logp - ref_chosen_logp) - (policy_rejected_logp - ref_rejected_logp)
dpo_loss = -torch.nn.functional.logsigmoid(beta * diff).mean()
dpo_loss

In [None]:
diff