In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification
)

In [2]:
model_name = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Policy model
policy = AutoModelForCausalLM.from_pretrained(model_name)
policy_ref = AutoModelForCausalLM.from_pretrained(model_name)
policy.train()
policy_ref.eval()

# Reward model
reward_model = AutoModelForSequenceClassification.from_pretrained(
    model_name, num_labels=1
)
reward_model.eval()

Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at meta-llama/Llama-3.2-1B and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


LlamaForSequenceClassification(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
   

In [8]:
prompt = "I am studying"
inputs = tokenizer(
    prompt, return_tensors="pt", padding=True
)
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]

with torch.no_grad():
    gen_ids = policy.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=20,
        do_sample=True,
        top_k=50,
        temperature=1.0,
        pad_token_id=tokenizer.pad_token_id
    )
response_ids = gen_ids[:, input_ids.shape[-1]:]
query_response = torch.cat([input_ids, response_ids], dim=1)

In [9]:
# log-probs
def get_log_prob_sum(model, input_ids):
    labels = input_ids.clone()
    with torch.no_grad():
        outputs = model(input_ids=input_ids, labels=labels)
        loss = outputs.loss  # average negative log-likelihood
    return -loss  # return log-likelihood

logprob_policy = get_log_prob_sum(policy, query_response)
logprob_ref = get_log_prob_sum(policy_ref, query_response)

# reward score
with torch.no_grad():
    reward_inputs = tokenizer(
        tokenizer.decode(query_response[0], skip_special_tokens=True),
        return_tensors="pt", truncation=True, padding=True
    )
    reward = reward_model(**reward_inputs).logits.squeeze().detach()

In [10]:
# advantage, PPO loss, KL loss
baseline = reward.detach()
advantage = reward - baseline
log_ratio = logprob_policy - logprob_ref
ratio = torch.exp(log_ratio)

# PPO-clip loss
clip_eps = 0.2
loss1 = ratio * advantage
loss2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantage
ppo_clip_loss = -torch.min(loss1, loss2)

# KL loss (optional penalty)
kl_loss = torch.mean(log_ratio**2)

# Loss
kl_coef = 0.01  # KL-Pen
ppo_loss = ppo_clip_loss + kl_coef * kl_loss

In [11]:
print("=" * 50)
print(f"[Prompt]: {prompt}")
print(f"[Response]: {tokenizer.decode(response_ids[0], skip_special_tokens=True)}")
print(f"[Reward]: {reward.item():.4f}")
print(f"[Policy logprob]: {logprob_policy.item():.4f} | [Ref logprob]: {logprob_ref.item():.4f}")
print(f"[PPO Loss]: {ppo_loss.item():.4f}")
print("=" * 50)

[Prompt]: I am studying
[Response]:  1 Peter, chapters 1-5. It has been a great journey. I can’t
[Reward]: -2.5654
[Policy logprob]: -2.4995 | [Ref logprob]: -2.4995
[PPO Loss]: 0.0000
