In [None]:
# Direct Preference Optimization: Your Language Model is Secretly a Reward Model
# https://arxiv.org/abs/2305.18290

In [6]:
from functools import partial

import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

$$
\mathcal{L}_{\mathrm{DPO}}\left(\pi_\theta ; \pi_{\mathrm{ref}}\right)=-\mathbb{E}_{\left(x, y_w, y_l\right) \sim \mathcal{D}}\left[\log \sigma\left(\beta \log \frac{\pi_\theta\left(y_w \mid x\right)}{\pi_{\mathrm{ref}}\left(y_w \mid x\right)}-\beta \log \frac{\pi_\theta\left(y_l \mid x\right)}{\pi_{\mathrm{ref}}\left(y_l \mid x\right)}\right)\right]
$$

where,
- $\pi_\theta$ policy model (LLM to optimize)
- $ \pi_{\mathrm{ref}}$ reference model (original LLM before optimization)
- $\mathbb{E}$ is expected value
- $\sigma$ logistsic sigmoid function
- $\beta$ hyperparameter to control the divergence between the $\pi_{\theta}$ and $\pi_{ref}$

In [91]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cuda


In [92]:
dataset = load_dataset("jondurbin/py-dpo-v0.1", split="train")

In [93]:
dataset = dataset.train_test_split(test_size=0.1, seed=42)

In [94]:
dataset

DatasetDict({
    train: Dataset({
        features: ['prompt', 'chosen', 'rejected', 'id'],
        num_rows: 8519
    })
    test: Dataset({
        features: ['prompt', 'chosen', 'rejected', 'id'],
        num_rows: 947
    })
})

In [95]:
print("Prompt: ")
print(dataset["train"][0]["prompt"])
print("\n\nChosen: ")
print(dataset["train"][0]["chosen"])
print("\n\nRejected: ")
print(dataset["train"][0]["rejected"])

Prompt: 
What are some efficient algorithms in Python for finding the intersection between two sets of geographic coordinates within a given radius?


Chosen: 
One possible algorithm for finding the intersection between two sets of geographic coordinates within a given radius in Python is as follows:

1. Define a function that calculates the great circle distance between two points on the Earth's surface. This can be done using the Haversine formula, which takes into account the Earth's radius and the latitude and longitude of the two points.

2. Define a function that takes two sets of geographic coordinates (as lists of latitude and longitude pairs) and a radius as input, and returns the intersection between the two sets within the given radius. This can be done by looping over all pairs of points from the two sets and checking whether the great circle distance between them is less than the given radius. If so, add the pair to a list of intersecting points.

3. Use this function to f

In [96]:
tokenizer = AutoTokenizer.from_pretrained("gpt2")

In [97]:
def format_input(entry):
    instruction_text = (
        f"Below is an instruction that describes a task. "
        f"Write a response that appropriately completes the request."
        f"\n\n### Instruction:\n{entry['prompt']}"
    )
    return instruction_text

In [98]:
def tokenize_fn(entry: dict):

    prompt = format_input(entry)
    chosen_response = entry["chosen"]
    rejected_response = entry["rejected"]

    chosen_full_text = f"{prompt}\n\n### Response:\n{chosen_response}"
    rejected_full_text = f"{prompt}\n\n### Response:\n{rejected_response}"

    prompt_tokens = tokenizer(chosen_full_text, truncation=True, padding=False, max_length=1024, return_attention_mask=False)["input_ids"]
    chosen_full_tokens = tokenizer(chosen_full_text, truncation=True, padding=False, max_length=1024, return_attention_mask=False)["input_ids"]
    rejected_full_tokens = tokenizer(rejected_full_text, truncation=True, padding=False, max_length=1024, return_attention_mask=False)["input_ids"]

    return { "prompt": prompt_tokens, "chosen": chosen_full_tokens, "rejected": rejected_full_tokens }

In [99]:
tokenized_dataset = dataset.map(tokenize_fn, remove_columns=["id"])

Map:   0%|          | 0/947 [00:00<?, ? examples/s]

Map: 100%|██████████| 947/947 [00:04<00:00, 221.94 examples/s]


In [100]:
print("Prompt: ")
print(tokenized_dataset["train"][0]["prompt"])
print("\n\nChosen: ")
print(tokenized_dataset["train"][0]["chosen"])
print("\n\nRejected: ")
print(tokenized_dataset["train"][0]["rejected"])

Prompt: 
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 389, 617, 6942, 16113, 287, 11361, 329, 4917, 262, 16246, 1022, 734, 5621, 286, 22987, 22715, 1626, 257, 1813, 16874, 30, 198, 198, 21017, 18261, 25, 198, 3198, 1744, 11862, 329, 4917, 262, 16246, 1022, 734, 5621, 286, 22987, 22715, 1626, 257, 1813, 16874, 287, 11361, 318, 355, 5679, 25, 198, 198, 16, 13, 2896, 500, 257, 2163, 326, 43707, 262, 1049, 9197, 5253, 1022, 734, 2173, 319, 262, 3668, 338, 4417, 13, 770, 460, 307, 1760, 1262, 262, 9398, 690, 500, 10451, 11, 543, 2753, 656, 1848, 262, 3668, 338, 16874, 290, 262, 32477, 290, 890, 3984, 286, 262, 734, 2173, 13, 198, 198, 17, 13, 2896, 500, 257, 2163, 326, 2753, 734, 5621, 286, 22987, 22715, 357, 292, 8341, 286, 32477, 290, 890, 3984, 14729, 8, 290, 257, 16874, 355, 5128, 11, 290, 5860, 262, 16246, 1022, 262, 734, 5621, 1626, 262, 1813, 16874, 13, 770, 460, 307, 1760, 416, 9052, 278

In [101]:
def collate_fn(
    batch,
    pad_token_id = 50256,
    context_length = None,
    mask_prompt_tokens = True,
    device = "cpu",
):

    batch_data = {
        "prompt": [],
        "chosen": [],
        "rejected": [],
        "rejected_mask": [],
        "chosen_mask": []

    }

    # dynamic padding
    max_length_common = 0
    for key in ["chosen", "rejected"]:
        current_max = max(len(item[key])+1 for item in batch)
        max_length_common = max(max_length_common, current_max)

    for item in batch:
        prompt = torch.tensor(item["prompt"])
        batch_data["prompt"].append(prompt)

        for key in ["chosen", "rejected"]:

            # adjust padding according to the common maximum length
            sequence = item[key]
            padded = sequence + [pad_token_id] * (max_length_common - len(sequence))
            mask = torch.ones(len(padded)).bool()

            # set mask for padding tokens to false after sequence ends
            mask[len(sequence):] = False

            # set mask for all prompt tokens to false
            # +2 sets the 2 newline ("\n") tokens before "### Response" to false
            if mask_prompt_tokens:
                mask[:prompt.shape[0]+2] = False

            batch_data[key].append(torch.tensor(padded))
            batch_data[f"{key}_mask"].append(mask)

    for key in ["chosen", "rejected", "chosen_mask", "rejected_mask"]:
        # stack all sequences into a tensor for the given key
        tensor_stack = torch.stack(batch_data[key])

        # truncate to maximum sequence length
        if context_length is not None:
            tensor_stack = tensor_stack[:, :context_length]

        batch_data[key] = tensor_stack.to(device)

    return batch_data

In [102]:
custom_collate_fn = partial(
    collate_fn,
    device=device,
    mask_prompt_tokens=True,
    context_length=1024 # model context length
)

In [103]:
train_loader = DataLoader(
    tokenized_dataset["train"],
    batch_size=8,
    shuffle=True,
    collate_fn=custom_collate_fn,
    num_workers=4,
)

test_loader = DataLoader(
    tokenized_dataset["test"],
    batch_size=8,
    collate_fn=custom_collate_fn,
    num_workers=4,
)

In [104]:
print("Train loader:")
for idx, batch in enumerate(train_loader):
    print(
        batch["chosen"].shape,
        batch["rejected"].shape,
    )
    if idx>10:
        break

print("\nTest loader:")
for idx, batch in enumerate(test_loader):
    print(
        batch["chosen"].shape,
        batch["rejected"].shape,
    )
    if idx>10:
        break

Train loader:


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/osvathm/ai-notebooks/venv/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 309, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/home/osvathm/ai-notebooks/venv/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
  File "/tmp/ipykernel_1330621/285878571.py", line 54, in collate_fn
    batch_data[key] = tensor_stack.to(device)
  File "/home/osvathm/ai-notebooks/venv/lib/python3.8/site-packages/torch/cuda/__init__.py", line 300, in _lazy_init
    raise RuntimeError(
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method


In [None]:
# Model

In [None]:
def compute_logprobs(
    logits, # (batch_size, num_tokens, vocab_size)
    labels, # (batch_size, num_tokens)
    selection_mask=None # (batch_size, num_tokens)
):

    # labels are the inputs shifted by one
    #   i.e. the model predicts the next token in the sequence,
    #   so the labels are the input tokens shifted to the left
    #   and the first token is ignored (because there's no previous token to predict)
    labels = labels[:, 1:].clone()

    # adjust logits to match the labels num_tokens
    #   i.e. truncate the logits by removing the last token's logits
    #   because there's no corresponding target (label) to predict for the last token.
    logits = logits[:, :-1, :]

    log_probs = F.log_softmax(logits, dim=-1)

    # select log probability the model assigned to the correct token at each position
    selected_log_probs = torch.gather(
        input=log_probs,
        dim=-1,
        index=labels.unsqueeze(-1)
    ).squeeze(-1)

    print("Log probs:", selected_log_probs)
    print("Log probs shape:", selected_log_probs.shape)

    if selection_mask is not None:
        mask = selection_mask[:, 1:].clone()

        print("Mask:", mask)
        print("Mask shape:", mask.shape)

        # apply the mask to filter out padding tokens
        #    i.e. set the log probabilities of padding tokens to 0
        selected_log_probs = selected_log_probs * mask

        print("selected Log probs:", selected_log_probs)
        print("selected Log probs shape:", selected_log_probs.shape)

        # Calculate the average log probability excluding padding tokens
        # This averages over the tokens, so the shape is (batch_size, num_tokens)
        avg_log_prob = selected_log_probs.sum(-1) / mask.sum(-1)

        return avg_log_prob

    else:
        return selected_log_probs.mean(-1)

In [36]:
# (batch_size, num_tokens, vocab_size)
logits = torch.tensor(
    [
        [
            [2.0, 0.1, 0.3, -0.4, 1.2],  # logit for token 1 in batch 1
            [1.2, -0.5, 0.9, 0.2, 0.3],  # logit for token 2 in batch 1
            [-0.3, 1.4, 0.2, 0.5, -0.2],  # logit for token 3 in batch 1
            [0.6, 0.1, -0.3, 0.4, 0.7]    # logit for token 4 in batch 1
        ],
        [
            [-0.2, 0.7, 1.1, 0.3, 0.4],  # logit for token 1 in batch 2
            [1.5, -0.8, 0.6, 0.2, -0.1],  # logit for token 2 in batch 2
            [0.1, 0.9, 0.4, -0.6, 1.2],  # logit for token 3 in batch 2
            [-0.3, 0.2, 0.4, 0.6, -0.2] # logit for token 4 in batch 2
        ]
    ]
)

# (batch_size, num_tokens)
labels = torch.tensor(
    [
        [1, 2, 3, 4],  # True labels for sequence 1
        [1, 2, 3, 0]   # True labels for sequence 2 (0 is padding)
    ]
)

# (batch_size, num_tokens)
selection_mask = torch.tensor(
    [
        [1, 1, 1, 1],   # No padding tokens in sequence 1
        [1, 1, 1, 0]    # Padding token at the last position in sequence 2
    ]
)

avg_log_probs = compute_logprobs(logits, labels, selection_mask)

print("Average Log Probabilities:", avg_log_probs)

Log probs: tensor([[-2.3272, -1.9925, -2.3383],
        [-1.0608, -1.9837, -2.0889]])
Log probs shape: torch.Size([2, 3])
Mask: tensor([[1, 1, 1],
        [1, 1, 0]])
Mask shape: torch.Size([2, 3])
selected Log probs: tensor([[-2.3272, -1.9925, -2.3383],
        [-1.0608, -1.9837, -0.0000]])
selected Log probs shape: torch.Size([2, 3])
Average Log Probabilities: tensor([-2.2193, -1.5223])


In [29]:
import torch

log_probs = torch.tensor([
    [  # batch 0
        [-1.0, -2.0, -3.0],  # timestep 0
        [-0.1, -0.2, -2.5]   # timestep 1
    ]
])  # shape: (1, 2, 3)

labels = torch.tensor([
    [0, 2]
])  # shape: (1, 2)


print(log_probs.shape)

labels = labels.unsqueeze(-1)  # shape: (1, 2, 1)

# print(labels)
print(labels.shape)
print(labels[0,1,0])

selected_log_probs = torch.gather(log_probs, dim=-1, index=labels)  # shape: (1, 2, 1)

print(selected_log_probs)
print(selected_log_probs.shape)

selected_log_probs = selected_log_probs.squeeze(-1)  # shape: (1, 2)

print(selected_log_probs)
print(selected_log_probs.shape)

torch.Size([1, 2, 3])
torch.Size([1, 2, 1])
tensor(2)
tensor([[[-1.0000],
         [-2.5000]]])
torch.Size([1, 2, 1])
tensor([[-1.0000, -2.5000]])
torch.Size([1, 2])


In [None]:
def compute_dpo_loss(
      model_chosen_logprobs,
      model_rejected_logprobs,
      reference_chosen_logprobs,
      reference_rejected_logprobs,
      beta=0.1,
    ):

    model_logratios = model_chosen_logprobs - model_rejected_logprobs
    reference_logratios = reference_chosen_logprobs - reference_rejected_logprobs
    logits = model_logratios - reference_logratios

    losses = -F.logsigmoid(beta * logits)

    # Optional values to track progress during training
    chosen_rewards = (model_chosen_logprobs - reference_chosen_logprobs).detach()
    rejected_rewards = (model_rejected_logprobs - reference_rejected_logprobs).detach()

    # .mean() to average over the samples in the batch
    return losses.mean(), chosen_rewards.mean(), rejected_rewards.mean()