In [None]:
from importlib.metadata import version

# Report the installed version of each package listed here.
pkgs = ["tiktoken", "torch"]
for p in pkgs:
    print(f"{p} version: {version(p)}")

In [None]:
import json

# Parse the JSON file (path is relative to the working directory).
file_path = "instruction-data-with-preference.json"
with open(file_path, encoding="utf-8") as f:
    data = json.load(f)

# Cell output: length of the loaded object.
len(data)

In [None]:
import pprint

# Inspect one raw entry (index 50) of the loaded data.
pprint.pp(data[50])

In [None]:
# Inspect the entry at index 999 (raises IndexError if fewer than 1000 entries were loaded).
pprint.pp(data[999])

In [None]:
# Inspect the entry at index 900.
pprint.pp(data[900])

In [6]:
def format_input(entry):
    """Build the Alpaca-style prompt text for one dataset entry.

    Args:
        entry: Mapping with an 'instruction' key and an optional 'input' key.

    Returns:
        The prompt string: a fixed preamble, an "### Instruction:" section and,
        only when the entry has a non-empty 'input', an "### Input:" section.

    Raises:
        KeyError: If 'instruction' is missing from ``entry``.
    """
    instruction_text = (
        "Below is an instruction that describes a task. "
        # "appropriately" (not "approximately") is the wording of the Alpaca prompt
        "Write a response that appropriately completes the request."
        f"\n\n### Instruction:\n{entry['instruction']}"
    )
    # .get() lets entries without an 'input' key behave like an empty input.
    extra_input = entry.get('input')
    input_text = f"\n\n### Input:\n{extra_input}" if extra_input else ""
    return instruction_text + input_text

In [None]:
# Build and show the prompt text for entry 50.
model_input = format_input(data[50])
print(model_input)

In [None]:
# Response section built from the entry's 'chosen' answer.
desired_response = f"### Response: \n{data[50]['chosen']}"
print(desired_response)

In [None]:
# Response section built from the entry's 'rejected' answer, for comparison.
possible_response = f"### Response: \n{data[50]['rejected']}"
print(possible_response)

In [None]:
def response_format(entry):
    """Return the response section for ``entry``, built from its 'chosen' answer."""
    return f"### Response: \n{entry['chosen']}"


print(response_format(data[50]))

In [11]:
# Split sizes: 85% train, 10% test, and whatever remains goes to validation.
train_portion = int(len(data) * 0.85)
test_portion = int(len(data) * 0.1)
val_portion = len(data) - train_portion - test_portion

# Contiguous, unshuffled slices in the order train -> test -> validation.
train_data, test_data, val_data = (
    data[:train_portion],
    data[train_portion: train_portion + test_portion],
    data[train_portion + test_portion:],
)

In [None]:
# Sizes of the three splits (train, test, validation).
len(train_data), len(test_data), len(val_data)

In [13]:
import torch
from torch.utils.data import Dataset

class PreferenceDataset(Dataset):
    """Dataset of tokenized prompt / chosen / rejected texts.

    Every entry is tokenized once in ``__init__``; ``__getitem__`` returns the
    stored result without further processing.
    """

    def __init__(self, data, tokenizer):
        """Tokenize every entry of ``data``.

        Args:
            data: Iterable of entries; each must provide 'chosen' and
                'rejected' plus whatever ``format_input`` reads from it.
            tokenizer: Object exposing ``encode(text)``; its return value is
                stored as-is.
        """
        self.data = data

        # "### Response:" follows the same "### <Name>:" spelling as the
        # "### Instruction:" / "### Input:" headers emitted by format_input.
        response_header = "\n\n### Response:\n"

        self.encoded_texts = []
        for entry in data:
            prompt = format_input(entry)
            chosen_full_text = f"{prompt}{response_header}{entry['chosen']}"
            rejected_full_text = f"{prompt}{response_header}{entry['rejected']}"

            self.encoded_texts.append({
                # The prompt is encoded on its own; its length is what tells
                # the prompt part apart inside the two full sequences.
                'prompt': tokenizer.encode(prompt),
                'chosen': tokenizer.encode(chosen_full_text),
                'rejected': tokenizer.encode(rejected_full_text),
            })

    def __getitem__(self, index):
        """Return the dict of encoded 'prompt', 'chosen' and 'rejected' at ``index``."""
        return self.encoded_texts[index]

    def __len__(self):
        """Return the number of encoded entries."""
        return len(self.encoded_texts)


In [None]:
# Quick check of how torch.stack combines two equally sized 1-D tensors.
a = torch.ones([10])
b = torch.zeros([10])
c = [a, b]
d = torch.stack(c)
d.shape

In [15]:
def custom_collate_fn(
        batch,
        pad_token_id=50256,
        allowed_max_length=None,
        mask_prompt_tokens=True,
        device='cpu'
):
    """Pad a batch of tokenized preference entries and build boolean masks.

    Args:
        batch: Sequence of dicts holding token-id lists under 'prompt',
            'chosen' and 'rejected'.
        pad_token_id: Token id appended to bring every sequence to a common
            length.
        allowed_max_length: If not None, the stacked tensors are cut to this
            many columns.
        mask_prompt_tokens: If True, the first ``len(prompt) + 2`` mask
            positions of every sequence are set to False.
        device: Device the stacked tensors are moved to.

    Returns:
        Dict with 'prompt' (list of 1-D tensors, one per item, neither padded
        nor moved to ``device``), 'chosen' / 'rejected' (padded id tensors of
        shape [batch, length]) and 'chosen_mask' / 'rejected_mask' (bool
        tensors of the same shape, False on padding and, when
        ``mask_prompt_tokens`` is set, on the prompt prefix).
    """
    response_keys = ('chosen', 'rejected')

    batch_data = {
        'prompt': [],
        'chosen': [],
        'rejected': [],
        'rejected_mask': [],
        'chosen_mask': []
    }

    # Common length = longest chosen/rejected sequence + 1, so even the
    # longest sequence ends with at least one pad token.
    max_length_common = max(
        (len(item[key]) + 1 for key in response_keys for item in batch),
        default=0,
    )

    for item in batch:
        prompt = torch.tensor(item['prompt'])
        batch_data['prompt'].append(prompt)

        for key in response_keys:
            sequence = item[key]
            n_real = len(sequence)
            padded = sequence + [pad_token_id] * (max_length_common - n_real)

            # Start all-True, then switch off padding positions.
            mask = torch.ones(len(padded), dtype=torch.bool)
            mask[n_real:] = False

            # Switch off the prompt positions plus the next 2 tokens
            # (the tokens that open the response header).
            if mask_prompt_tokens:
                mask[:prompt.shape[0] + 2] = False

            batch_data[key].append(torch.tensor(padded))
            batch_data[f"{key}_mask"].append(mask)

    # Stack each per-item list into one [B, max_length_common] tensor,
    # optionally truncate, and move it to the target device.
    for key in ('chosen', 'rejected', 'chosen_mask', 'rejected_mask'):
        stacked = torch.stack(batch_data[key])
        if allowed_max_length is not None:
            stacked = stacked[:, :allowed_max_length]
        batch_data[key] = stacked.to(device)

    return batch_data


In [None]:
from functools import partial
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Pre-bind the collate options so the result can be passed as a
# single-argument `collate_fn` to a DataLoader.
customized_collate_fn = partial(
    custom_collate_fn,
    device=device,
    mask_prompt_tokens=True,
    allowed_max_length=1024,
)

In [None]:
# Summarize the loaded dataset with plain Python; the former `data??` was
# leftover IPython-only introspection and is not valid Python syntax.
type(data), len(data)

In [None]:
# Take the first two entries as a tiny example set and show them.
example_data = data[:2]
for i in example_data:
    pprint.pp(i)

In [19]:
import tiktoken
from torch.utils.data import DataLoader
# Tokenizer using tiktoken's 'gpt2' encoding; reused by the datasets below.
tokenizer = tiktoken.get_encoding('gpt2')

# Tiny dataset/loader over the two example entries; batch_size=2 puts both
# entries into a single batch, in their original order (shuffle=False).
example_dataset = PreferenceDataset(example_data, tokenizer)
example_dataloader = DataLoader(
    example_dataset,
    batch_size=2,
    collate_fn=customized_collate_fn,
    shuffle=False
)

In [None]:
# Pull the first batch from the example loader and list the keys produced
# by the collate function.
batch = next(iter(example_dataloader))
batch.keys()

In [None]:
# Prompts are kept as one tensor per item (not padded), so the two shapes
# may differ.
batch['prompt'][0].shape, batch['prompt'][1].shape

In [None]:
# Shape of the stacked, padded 'chosen' tensor.
batch['chosen'].shape

In [None]:
# Padded 'rejected' token ids; trailing pad_token_id values are padding.
batch['rejected']

In [24]:
def decode_tokens_from_batch(token_ids, tokenizer):
    """Decode token ids back into text.

    Args:
        token_ids: Object exposing ``flatten().tolist()`` (a tensor in this
            notebook).
        tokenizer: Object exposing ``decode(list_of_ids)``.

    Returns:
        Whatever ``tokenizer.decode`` returns for the flattened id list.
    """
    flat_ids = token_ids.flatten()
    return tokenizer.decode(flat_ids.tolist())

In [None]:
# Decode the first item's prompt tokens back to text.
text = decode_tokens_from_batch(
    token_ids=batch['prompt'][0],
    tokenizer=tokenizer
)
print(text)

In [None]:
# Decode the first item's full padded 'rejected' sequence (padding tokens
# are decoded too, since no mask is applied here).
text = decode_tokens_from_batch(
    token_ids=batch['rejected'][0],
    tokenizer=tokenizer
)
print(text)

In [None]:
# Number of prompt tokens of the first item.
batch['prompt'][0].shape

In [None]:
# Boolean mask for 'chosen': False on the masked prompt prefix and on padding.
batch['chosen_mask']

In [None]:
# Decode only the 'rejected' tokens whose mask entry is True.
text = decode_tokens_from_batch(
    token_ids=batch['rejected'][0][batch['rejected_mask'][0]],
    tokenizer=tokenizer
)
print(text)

In [None]:
# Decode only the 'chosen' tokens whose mask entry is True.
text = decode_tokens_from_batch(
    token_ids=batch['chosen'][0][batch['chosen_mask'][0]],
    tokenizer=tokenizer
)
print(text)

##### The masks are used to ignore prompt and padding tokens when computing the DPO loss.

In [31]:
from torch.utils.data import DataLoader
# Loader settings shared by the train / validation / test loaders.
num_workers = 0
batch_size = 8

# Seed torch's global RNG before building the shuffled training loader.
torch.manual_seed(123)
train_dataset = PreferenceDataset(train_data, tokenizer)
# Training loader: shuffled, with drop_last=True.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    collate_fn=customized_collate_fn,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers
)

In [33]:
# Validation loader: fixed order (shuffle=False) and drop_last=False.
val_dataset = PreferenceDataset(val_data, tokenizer)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    collate_fn=customized_collate_fn,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers,
)

In [34]:
# Test loader: same settings as the validation loader, over the test split.
test_dataset = PreferenceDataset(test_data, tokenizer)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    collate_fn=customized_collate_fn,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers
)

In [None]:
# Print the 'chosen' / 'rejected' tensor shapes of the first training batch
# only; note that this rebinds the notebook-level name `batch`.
for batch in train_loader:
    print(batch['chosen'].shape, batch['rejected'].shape)
    break