In [2]:
import torch

def dpo_loss_bradley_terry(logp_ref, logp_train, x, y_win, y_lose, beta=0.1):
    """Pairwise DPO loss under the Bradley-Terry preference model.

    The implicit reward of a completion is ``beta`` times the log-ratio between the
    trained and the reference policy. The loss is the negative log-sigmoid of the
    reward margin of the preferred completion over the rejected one.

    Args:
        logp_ref: callable ``(x, y) -> log p_ref(y | x)``.
        logp_train: callable ``(x, y) -> log p_train(y | x)``.
        x: the prompt.
        y_win: the preferred completion.
        y_lose: the rejected completion.
        beta: scale applied to each log-ratio.

    Returns:
        ``-logsigmoid(reward_win - reward_lose)``, computed elementwise.
    """
    log_sigmoid = torch.nn.functional.logsigmoid
    reward_win = beta * (logp_train(x, y_win) - logp_ref(x, y_win))
    reward_lose = beta * (logp_train(x, y_lose) - logp_ref(x, y_lose))
    margin = reward_win - reward_lose
    return -log_sigmoid(margin)

def dpo_loss_plackett_luce(logp_ref, logp_train, x, y_best_to_worst, beta=0.1):
    """DPO loss for a full ranking under the Plackett-Luce model.

    With implicit rewards ``r_k = beta * (log p_train(y_k|x) - log p_ref(y_k|x))`` for
    the completions ordered best to worst, the ranking's negative log-likelihood is

        -sum_k [ r_k - logsumexp(r_k, ..., r_{N-1}) ]

    For two completions this reduces to the Bradley-Terry DPO loss.

    Args:
        logp_ref: callable ``(x, y) -> log p_ref(y | x)``.
        logp_train: callable ``(x, y) -> log p_train(y | x)``.
        x: the prompt.
        y_best_to_worst: sequence of completions, most preferred first.
        beta: scale applied to each log-ratio.

    Returns:
        Scalar tensor with the summed loss (0 for an empty ranking).
    """
    completions = list(y_best_to_worst)
    if not completions:
        return torch.zeros(())
    # Stacking (instead of writing into a preallocated buffer) keeps the autograd
    # graph, dtype and device of the model outputs.
    rewards = torch.stack([
        torch.as_tensor(beta * (logp_train(x, y) - logp_ref(x, y)))
        for y in completions
    ])
    # logsumexp over every suffix rewards[k:], computed as a cumulative
    # logsumexp over the reversed ranking and flipped back.
    suffix_logsumexp = torch.logcumsumexp(rewards.flip(0), dim=0).flip(0)
    return -torch.sum(rewards - suffix_logsumexp)

In [4]:
# Off-policy setup: fine-tune the model until it is close to the reference model, then use the on-policy setup.
# Interactive setup: sample completions from the policy, have a human rank them, then use the on-policy setup.
def interactive_dpo(ref, train, prompts, ask=input):
    """Collect pairwise human preferences over completions drawn from ``ref``.

    For each prompt, ``ref`` is called twice to get two completions, and the user is
    asked which one is better. Answers other than "0" or "1" are re-asked.

    Args:
        ref: callable mapping a prompt to a completion. Presumably it samples
            stochastically so the two calls differ (TODO confirm).
        train: not used here; kept for interface compatibility.
        prompts: iterable of prompts.
        ask: function used to query the user (prompt string -> answer string);
            defaults to the builtin ``input``.

    Returns:
        List of ``(x, y_win, y_lose)`` tuples.
    """
    dataset = []
    for x in prompts:
        y0, y1 = ref(x), ref(x)
        answer = ask(f"Which is better? 0:{y0} or 1:{y1}?").strip()
        while answer not in ("0", "1"):
            answer = ask("Please answer 0 or 1: ").strip()
        # Parenthesize both tuples: without them the conditional binds only to the
        # middle element and the right-hand side becomes a 3-tuple.
        y_win, y_lose = (y0, y1) if answer == "0" else (y1, y0)
        dataset.append((x, y_win, y_lose))
    return dataset

In [5]:
# On-policy setup
# x: tokenized prompt, y: continuation
def logp_model(model):
    """Wrap ``model`` as a function computing log p(y | x) per batch row.

    Args:
        model: callable taking a ``(batch, seq)`` tensor of token ids. It returns either
            an object with a ``logits`` attribute (as Hugging Face causal-LM outputs do),
            an object with a ``logits()`` method, or the logits tensor itself. The
            logits are indexed as ``(batch, seq, vocab)``.

    Returns:
        A function ``logp(x, y)``:
        - ``x`` is the tokenized prompt and ``y`` the tokenized continuation. Both are
          ``(batch, len)`` integer tensors.
        - It returns a ``(batch,)`` tensor of summed continuation-token log-probabilities.

        Padding / attention masks are not handled.
        NOTE(review): assumes all rows are unpadded.
    """
    def logp(x, y):
        prompt_len = x.shape[1]
        if prompt_len == 0:
            raise ValueError("logp requires a non-empty prompt: the first token has no predicting logit")
        # Concatenate along the sequence dimension (dim 1 of (batch, seq) token ids).
        input_ids = torch.cat((x, y), dim=1)
        output = model(input_ids)
        logits = getattr(output, "logits", output)
        # Keep supporting wrappers that expose logits as a method.
        if callable(logits) and not isinstance(logits, torch.Tensor):
            logits = logits()
        # Logits at position t predict token t + 1, so the continuation tokens
        # (positions prompt_len .. end) are predicted by positions prompt_len-1 .. end-1.
        continuation_logits = logits[:, prompt_len - 1:-1, :]
        # Normalize over the vocabulary dimension to get log-probabilities.
        log_probs = torch.nn.functional.log_softmax(continuation_logits, dim=-1)
        # Select the log-probability of each actual continuation token.
        token_logp = torch.gather(log_probs, -1, y.unsqueeze(-1)).squeeze(-1)
        # Sum over the sequence to get log p of the whole continuation.
        return token_logp.sum(dim=-1)
    return logp

def on_policy_dpo(ref, train, dataset, beta=0.1, lr=1e-3):
    """Fine-tune ``train`` with the Bradley-Terry DPO loss, one optimizer step per example.

    Side effects:
    - Freezes ``ref``: sets ``requires_grad`` to False on its parameters and puts it in
      eval mode.
    - Updates ``train``'s parameters in place with AdamW.

    Args:
        ref: reference model, typically a copy of the model to be trained
            (presumably a ``torch.nn.Module``; TODO confirm).
        train: model being trained.
        dataset: iterable of ``(x, y_win, y_lose)`` tokenized triples.
        beta: DPO temperature passed to the loss.
        lr: AdamW learning rate. The default equals ``torch.optim.AdamW``'s own
            default, so behavior is unchanged when it is omitted.

    Returns:
        List of per-step loss values (floats).
    """
    # Frozen reference: no gradients, and eval mode so stochastic layers such as
    # dropout do not perturb its log-probabilities.
    ref.eval()
    for param in ref.parameters():
        param.requires_grad_(False)
    logp_ref = logp_model(ref)
    logp_train = logp_model(train)
    optimizer = torch.optim.AdamW(train.parameters(), lr=lr)
    losses = []
    for x, y_win, y_lose in dataset:
        # The loss is elementwise over batch rows; average it so backward() gets a scalar.
        loss = dpo_loss_bradley_terry(logp_ref, logp_train, x, y_win, y_lose, beta=beta).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses
