In [None]:
from importlib.metadata import version

pkgs = [
    "tiktoken",
    "torch",
]
for p in pkgs:
    print(f"{p} version: {version(p)}")

In [None]:
import gc
import torch
gc.collect()
torch.cuda.empty_cache()

In [None]:
import json

# Load the instruction/preference records from disk.
# NOTE(review): absolute, machine-specific path — consider a configurable data directory.
file_path = "/home/htkumar/llms/rasbt_llms_from_scratch/instruction-data-with-preference.json"
with open(file_path, encoding='utf-8') as json_file:
    data = json.loads(json_file.read())

# Number of loaded records.
len(data)

In [None]:
import pprint

pprint.pp(data[50])

In [None]:
pprint.pp(data[999])

In [None]:
pprint.pp(data[900])

In [None]:
def format_input(entry):
    """Build the prompt text for one record.

    The prompt always contains the fixed preamble and an "### Instruction:"
    section; an "### Input:" section is appended only when ``entry['input']``
    is truthy.

    Args:
        entry: Mapping with at least the keys ``'instruction'`` and ``'input'``.

    Returns:
        The prompt string (without any response section).
    """
    # NOTE(review): "approximately" may be a typo for "appropriately"; it is left
    # unchanged because the prompt has to match whatever wording the finetuned
    # checkpoint was trained with — confirm before editing.
    sections = [
        "Below is an instruction that describes a task. "
        "Write a response that approximately completes the request."
        f"\n\n### Instruction:\n{entry['instruction']}"
    ]
    if entry['input']:
        sections.append(f"\n\n### Input:\n{entry['input']}")
    return "".join(sections)

In [None]:
model_input = format_input(data[50])
print(model_input)

In [None]:
desired_response = f"### Response: \n{data[50]['chosen']}"
print(desired_response)

In [None]:
possible_response = f"### Response: \n{data[50]['rejected']}"
print(possible_response)

In [None]:
def response_format(entry):
    """Return the response section for a record, built from its 'chosen' answer."""
    return f"### Response: \n{entry['chosen']}"
print(response_format(data[50]))

In [None]:
# Contiguous split: 85% train, 10% test, the remainder validation.
n_total = len(data)
train_portion = int(n_total * 0.85)
test_portion = int(n_total * 0.1)
val_portion = n_total - train_portion - test_portion

# Index where the test slice ends and the validation slice begins.
test_end = train_portion + test_portion
train_data, test_data, val_data = (
    data[:train_portion],
    data[train_portion:test_end],
    data[test_end:],
)

In [None]:
len(train_data), len(test_data), len(val_data)

In [None]:
import torch
from torch.utils.data import Dataset

class PreferenceDataset(Dataset):
    """Pre-tokenized preference pairs for DPO training.

    Each record is turned into three token-id lists:

    * ``'prompt'``   – the prompt produced by ``format_input(entry)``
    * ``'chosen'``   – prompt + response header + ``entry['chosen']``
    * ``'rejected'`` – prompt + response header + ``entry['rejected']``

    Args:
        data: Sequence of records; each must provide the keys read by
            ``format_input`` plus ``'chosen'`` and ``'rejected'``.
        tokenizer: Object exposing ``encode(text) -> list of token ids``.
    """

    def __init__(self, data, tokenizer):
        self.data = data

        self.encoded_texts = []
        for entry in data:
            prompt = format_input(entry)
            # The header is spelled "### Response:" (with a space) so that it matches
            # the marker the rest of the notebook prints and strips from generations.
            response_header = "\n\n### Response:\n"
            chosen_full_text = f"{prompt}{response_header}{entry['chosen']}"
            rejected_full_text = f"{prompt}{response_header}{entry['rejected']}"

            self.encoded_texts.append({
                'prompt': tokenizer.encode(prompt),
                'chosen': tokenizer.encode(chosen_full_text),
                'rejected': tokenizer.encode(rejected_full_text),
            })

    def __getitem__(self, index):
        """Return the dict of token-id lists for the record at ``index``."""
        return self.encoded_texts[index]

    def __len__(self):
        """Return the number of encoded records."""
        return len(self.encoded_texts)


In [None]:
a = torch.ones([10]); b = torch.zeros([10])
c = [a, b]
d = torch.stack(c); d.shape

In [None]:
def custom_collate_fn(
        batch,
        pad_token_id=50256,
        allowed_max_length=None,
        mask_prompt_tokens=True,
        device='cpu'
):
    """Collate preference records into padded tensors plus selection masks.

    Args:
        batch: List of dicts with token-id lists under 'prompt', 'chosen'
            and 'rejected'.
        pad_token_id: Id used to right-pad 'chosen' and 'rejected'.
        allowed_max_length: If given, stacked tensors are truncated to this
            many columns.
        mask_prompt_tokens: If True, the mask is also False for the first
            ``len(prompt) + 2`` positions of each sequence.
        device: Device the stacked tensors are moved to.

    Returns:
        Dict with 'prompt' (list of 1-D tensors, left on their original
        device) and stacked 'chosen', 'rejected', 'chosen_mask',
        'rejected_mask' tensors of shape [B, L].
    """
    response_keys = ('chosen', 'rejected')
    batch_data = {
        'prompt': [],
        'chosen': [],
        'rejected': [],
        'rejected_mask': [],
        'chosen_mask': []
    }

    # Common padded length: longest chosen/rejected sequence plus one, so every
    # row ends with at least one padding token. Zero for an empty batch.
    target_len = max(
        (len(sample[name]) + 1 for name in response_keys for sample in batch),
        default=0,
    )

    for sample in batch:
        prompt_ids = torch.tensor(sample['prompt'])
        batch_data['prompt'].append(prompt_ids)

        for name in response_keys:
            token_ids = sample[name]
            n_real = len(token_ids)
            padded_ids = token_ids + [pad_token_id] * (target_len - n_real)

            # Start from all-False and switch on only the positions to keep:
            # real (non-padding) tokens, optionally skipping the prompt plus the
            # two tokens that follow it (intended to cover the separator before
            # the response header — TODO confirm against the tokenizer).
            keep = torch.zeros(len(padded_ids), dtype=torch.bool)
            first_kept = prompt_ids.shape[0] + 2 if mask_prompt_tokens else 0
            keep[first_kept:n_real] = True

            batch_data[name].append(torch.tensor(padded_ids))
            batch_data[f"{name}_mask"].append(keep)

    # Stack each per-row list into a [B, target_len] tensor, truncate if asked,
    # and move it to the target device.
    for name in ('chosen', 'rejected', 'chosen_mask', 'rejected_mask'):
        stacked = torch.stack(batch_data[name])
        if allowed_max_length is not None:
            stacked = stacked[:, :allowed_max_length]
        batch_data[name] = stacked.to(device)

    return batch_data


In [None]:
from functools import partial
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cpu'
print(device)

customized_collate_fn = partial(
    custom_collate_fn,
    device=device,
    mask_prompt_tokens=True,
    allowed_max_length=1024,
)

In [None]:
# Quick look at the loaded dataset: container type and number of records.
type(data), len(data)

In [None]:
example_data = data[:2]
for i in example_data:
    pprint.pp(i)

In [None]:
import tiktoken
from torch.utils.data import DataLoader
tokenizer = tiktoken.get_encoding('gpt2')

example_dataset = PreferenceDataset(example_data, tokenizer)
example_dataloader = DataLoader(
    example_dataset,
    batch_size=2,
    collate_fn=customized_collate_fn,
    shuffle=False
)

In [None]:
batch = next(iter(example_dataloader))
batch.keys()

In [None]:
batch['prompt'][0].shape, batch['prompt'][1].shape

In [None]:
batch['chosen'].shape

In [None]:
batch['rejected']

In [None]:
def decode_tokens_from_batch(token_ids, tokenizer):
    """Decode a tensor of token ids into text.

    The tensor is flattened first, so any shape is accepted; the ids are
    decoded in flattened order with ``tokenizer.decode``.
    """
    return tokenizer.decode(token_ids.flatten().tolist())

In [None]:
text = decode_tokens_from_batch(
    token_ids=batch['prompt'][0],
    tokenizer=tokenizer
)
print(text)

In [None]:
text = decode_tokens_from_batch(
    token_ids=batch['rejected'][0],
    tokenizer=tokenizer
)
print(text)

In [None]:
batch['prompt'][0].shape

In [None]:
batch['chosen_mask']

In [None]:
text = decode_tokens_from_batch(
    token_ids=batch['rejected'][0][batch['rejected_mask'][0]],
    tokenizer=tokenizer
)
print(text)

In [None]:
text = decode_tokens_from_batch(
    token_ids=batch['chosen'][0][batch['chosen_mask'][0]],
    tokenizer=tokenizer
)
print(text)

##### The mask is used to ignore prompt and padding tokens when computing the DPO loss.

In [None]:
from torch.utils.data import DataLoader
num_workers = 0
batch_size = 8

torch.manual_seed(123)
train_dataset = PreferenceDataset(train_data, tokenizer)
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    collate_fn=customized_collate_fn,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers
)

In [None]:
val_dataset = PreferenceDataset(val_data, tokenizer)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    collate_fn=customized_collate_fn,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers,
)

In [None]:
test_dataset = PreferenceDataset(test_data, tokenizer)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    collate_fn=customized_collate_fn,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers
)

In [None]:
for batch in train_loader:
    print(batch['chosen'].shape, batch['rejected'].shape)
    break

In [None]:
# Load instruction finetuned model
import os
from pathlib import Path

finetuned_model_path = Path('/home/htkumar/llms/rasbt_llms_from_scratch/gpt2-medium-sft.pth')

In [None]:
from gpt_model import GPTModel, generate, text_to_token_ids, token_ids_to_text
from gpt_download import load_gpt2, BASE_CONFIG, model_configs

CHOOSE_MODEL = "gpt2-medium (355M)"
model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
BASE_CONFIG.update(model_configs[CHOOSE_MODEL])
model = GPTModel(BASE_CONFIG)

In [None]:
device

In [None]:
# Load the instruction-finetuned weights into the policy model, reusing the
# checkpoint path defined once in `finetuned_model_path` instead of repeating it.
model.load_state_dict(
    torch.load(
        finetuned_model_path,
        map_location=device,
        weights_only=True
    )
)
model.eval();

In [None]:
input = format_input(data[2])
print(input)

In [None]:
token_ids = generate(
    model=model,
    idx=text_to_token_ids(input, tokenizer),
    max_new_tokens=35,
    context_size=BASE_CONFIG['context_length'],
    eos_id=50256,
)
generated_text = token_ids_to_text(token_ids, tokenizer)
print(generated_text)

In [None]:
# The policy is the model that gets optimized; the reference is a second copy
# loaded from the same checkpoint and kept fixed.
policy_model = model

reference_model = GPTModel(BASE_CONFIG)
reference_model.load_state_dict(
    torch.load(
        finetuned_model_path,
        map_location=device,
        weights_only=True
    )
)
reference_model.eval()
# The reference model is never optimized: freeze its parameters so autograd
# does not track its forward passes during training.
reference_model.requires_grad_(False)

policy_model.to(device)
reference_model.to(device);

#### DPO loss function

In [None]:
import torch.nn.functional as F

def compute_dpo_loss(
    model_chosen_logprobs,
    model_rejected_logprobs,
    reference_chosen_logprobs,
    reference_rejected_logprobs,
    beta=0.1
):
    """DPO loss computed from sequence log-probabilities.

    The loss is ``-logsigmoid(beta * (policy_margin - reference_margin))``
    where each margin is ``chosen - rejected`` for the respective model.

    Returns:
        Tuple ``(mean loss, mean chosen reward, mean rejected reward)``.
        The rewards are ``model - reference`` log-probabilities and are
        detached from the autograd graph.
    """
    policy_margin = model_chosen_logprobs - model_rejected_logprobs
    reference_margin = reference_chosen_logprobs - reference_rejected_logprobs
    scaled_logits = beta * (policy_margin - reference_margin)
    mean_loss = (-F.logsigmoid(scaled_logits)).mean()

    # Rewards are reported for monitoring only, hence detached.
    mean_chosen_reward = (model_chosen_logprobs - reference_chosen_logprobs).detach().mean()
    mean_rejected_reward = (model_rejected_logprobs - reference_rejected_logprobs).detach().mean()

    return mean_loss, mean_chosen_reward, mean_rejected_reward

In [None]:
def compute_dpo_loss_alt(
    model_chosen_logprobs,
    model_rejected_logprobs,
    reference_chosen_logprobs,
    reference_rejected_logprobs,
    beta=0.1
):
    """DPO loss with the terms grouped per response instead of per model.

    The logits are ``(model_chosen - reference_chosen) -
    (model_rejected - reference_rejected)``; the loss is
    ``-logsigmoid(beta * logits)``.

    Returns:
        Tuple ``(mean loss, mean chosen reward, mean rejected reward)``;
        the rewards are detached from the autograd graph.
    """
    chosen_advantage = model_chosen_logprobs - reference_chosen_logprobs
    rejected_advantage = model_rejected_logprobs - reference_rejected_logprobs

    per_example_loss = -F.logsigmoid(beta * (chosen_advantage - rejected_advantage))

    return (
        per_example_loss.mean(),
        chosen_advantage.detach().mean(),
        rejected_advantage.detach().mean(),
    )

In [None]:
a = torch.tensor([1., 2., 3.])
# Pass `dim` explicitly; relying on the implicit dimension is deprecated.
torch.log(F.softmax(a, dim=-1))

In [None]:
# Same values as log(softmax(a)), computed in one call; `dim` is given explicitly.
F.log_softmax(a, dim=-1)

In [None]:
# Sample data
logits = torch.tensor(
    [[2.0, 1.0, 0.1],
    [0.5, 2.5, 0.3]]
)
targets = torch.tensor([0, 2])
logits.shape, targets.shape

In [None]:
log_softmax_logits = F.log_softmax(logits, dim=1)
selected_log_probs = torch.gather(
    input=log_softmax_logits,
    dim=1,
    index=targets.unsqueeze(1)
).squeeze(1)
print(log_softmax_logits)
print(selected_log_probs)
print(selected_log_probs.shape)
manual_loss = -selected_log_probs.mean()
print(manual_loss)

In [None]:
cross_entropy_loss = F.cross_entropy(logits, targets)
print(cross_entropy_loss)

In [None]:
t = torch.tensor(
  [[1., 2.,],
   [3., 4.]]
)
m = torch.tensor(
    [[1, 1, 1],
    [0, 1, 1]]
)

selected_nums = torch.gather(
    input=t,
    dim=1,
    index=m
)
selected_nums

In [None]:
log_probs = torch.tensor([
    [0.5, 0.3, 0.2],
    [1.0, 2.0, 3.0]
])
print(log_probs.mean(-1, keepdim=True))
mask = torch.tensor([
    [False, True, True],
    [False, True, True]
])
log_probs = log_probs * mask
(log_probs.sum(-1) / mask.sum(-1)).shape
log_probs.mean(-1)

In [None]:
# Cross-entropy is the negative mean of the log-probabilities assigned to the correct labels.

In [None]:
next(iter(train_loader)).keys()

In [None]:
test_batch = next(iter(train_loader))

In [None]:
def compute_logprobs(logits, labels, selection_mask=None):
    """
    Average log-probability that `logits` assign to the next-token `labels`.

    logits is [B, num_tokens, vocab_size]
    labels is [B, num_tokens]
    selection_mask is [B, num_tokens]; positions that are False/0 are excluded
    from the average. If None, all positions are averaged.

    Returns a tensor of shape [B]. A row whose mask selects no position yields
    0.0 instead of NaN.
    """
    # The logits at position t predict the token at position t + 1, so drop the
    # last logit and the first label to align them.
    logits = logits[:, :-1, :]
    labels = labels[:, 1:]
    log_probs = F.log_softmax(logits, dim=-1)

    # shape is [B, num_tokens-1]: log-prob of the actual next token at every index.
    selected_log_probs = torch.gather(
        input=log_probs,
        dim=-1,
        index=labels.unsqueeze(-1)
    ).squeeze(-1)

    if selection_mask is None:
        return selected_log_probs.mean(-1)

    # Shift the mask the same way as the labels, then zero out excluded positions.
    mask = selection_mask[:, 1:]
    selected_log_probs = selected_log_probs * mask

    # Guard against rows with no selected token (e.g. a response truncated away),
    # which would otherwise produce 0 / 0 = NaN.
    num_selected = mask.sum(-1).clamp(min=1)
    return selected_log_probs.sum(-1) / num_selected

In [None]:
def compute_dpo_loss_batch(batch, policy_model, reference_model, beta):
    """Compute the DPO loss for one collated batch.

    Args:
        batch: Dict with 'chosen', 'rejected', 'chosen_mask' and
            'rejected_mask' entries.
        policy_model: Model being optimized; called on the token tensors.
        reference_model: Fixed model; its forward passes run without
            gradient tracking.
        beta: Scaling factor forwarded to ``compute_dpo_loss``.

    Returns:
        Whatever ``compute_dpo_loss`` returns (loss, chosen rewards,
        rejected rewards).
    """
    def sequence_logprobs(model, key):
        # Forward pass on batch[key], then masked average log-prob per sequence.
        return compute_logprobs(model(batch[key]), batch[key], batch[f"{key}_mask"])

    policy_model_chosen_logprobs = sequence_logprobs(policy_model, 'chosen')
    policy_model_rejected_logprobs = sequence_logprobs(policy_model, 'rejected')

    # The reference model only provides a baseline and is never updated, so
    # there is no need to keep its activations for backpropagation.
    with torch.no_grad():
        reference_model_chosen_logprobs = sequence_logprobs(reference_model, 'chosen')
        reference_model_rejected_logprobs = sequence_logprobs(reference_model, 'rejected')

    return compute_dpo_loss(
        policy_model_chosen_logprobs,
        policy_model_rejected_logprobs,
        reference_model_chosen_logprobs,
        reference_model_rejected_logprobs,
        beta
    )

In [None]:
compute_dpo_loss_batch(test_batch, policy_model, reference_model, 0.1)

In [None]:
def compute_dpo_loss_loader(data_loader, policy_model, reference_model, beta, num_batches=None):
    """Average the DPO loss and rewards over the first batches of a loader.

    Args:
        data_loader: Iterable of collated batches that supports ``len()``.
        policy_model: Forwarded to ``compute_dpo_loss_batch``.
        reference_model: Forwarded to ``compute_dpo_loss_batch``.
        beta: Forwarded to ``compute_dpo_loss_batch``.
        num_batches: Maximum number of batches to evaluate; None means all.
            Values larger than ``len(data_loader)`` are capped.

    Returns:
        Tuple of floats ``(mean loss, mean chosen reward, mean rejected
        reward)``. All three are NaN when there is no batch to evaluate.
    """
    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))

    # Nothing to average over: report NaN instead of dividing by zero.
    if num_batches <= 0:
        nan = float('nan')
        return nan, nan, nan

    total_loss, total_chosen_rewards, total_rejected_rewards = 0., 0., 0.
    for i, batch in enumerate(data_loader):
        if i >= num_batches:
            break
        loss, chosen_rewards, rejected_rewards = compute_dpo_loss_batch(
            batch,
            policy_model,
            reference_model,
            beta
        )
        total_loss += loss.item()
        total_chosen_rewards += chosen_rewards.item()
        total_rejected_rewards += rejected_rewards.item()

    return (
        total_loss / num_batches,
        total_chosen_rewards / num_batches,
        total_rejected_rewards / num_batches,
    )

In [None]:
def evaluate_dpo_loss_loader(policy_model, reference_model, train_loader, val_loader, beta, eval_iter):
    """Evaluate DPO loss and rewards on the train and validation loaders.

    The policy model is switched to eval mode for the measurement and put back
    into whichever mode it was in before the call. Only the policy model's
    mode is touched here; the reference model is used as passed in.

    Args:
        policy_model: Model being optimized.
        reference_model: Fixed reference model.
        train_loader: Loader for the training split.
        val_loader: Loader for the validation split.
        beta: Forwarded to ``compute_dpo_loss_loader``.
        eval_iter: Maximum number of batches evaluated per loader.

    Returns:
        Dict with 'train_loss', 'train_chosen_reward', 'train_rejected_reward',
        'val_loss', 'val_chosen_reward' and 'val_rejected_reward'.
    """
    # Remember the current mode so that evaluating a model that is in eval mode
    # does not silently leave it in train mode afterwards.
    was_training = policy_model.training
    policy_model.eval()
    try:
        with torch.no_grad():
            train_loss, train_chosen_rewards, train_rejected_rewards = compute_dpo_loss_loader(
                data_loader=train_loader,
                policy_model=policy_model,
                reference_model=reference_model,
                beta=beta,
                num_batches=eval_iter
            )

            val_loss, val_chosen_rewards, val_rejected_rewards = compute_dpo_loss_loader(
                data_loader=val_loader,
                policy_model=policy_model,
                reference_model=reference_model,
                beta=beta,
                num_batches=eval_iter
            )
    finally:
        policy_model.train(was_training)

    return {
        'train_loss': train_loss,
        'train_chosen_reward': train_chosen_rewards,
        'train_rejected_reward': train_rejected_rewards,
        'val_loss': val_loss,
        'val_chosen_reward': val_chosen_rewards,
        'val_rejected_reward': val_rejected_rewards
    }

In [None]:
from gpt_model import generate_and_print_sample

In [None]:
def generate_model_output(
    model,
    tokenizer,
    data,
    device='cpu'
):
    """Generate a response for one record and print it next to the reference answer.

    Args:
        model: Model used for generation.
        tokenizer: Tokenizer passed to the text/token conversion helpers.
        data: Single record; read through ``format_input`` and ``data['output']``.
        device: Device the prompt token ids are moved to before generation.

    Prints the prompt, the record's 'output' field and the model's response.
    Returns None.
    """
    input_text = format_input(data)

    token_ids = generate(
        # Use the model that was passed in, not a notebook-level global.
        model=model,
        idx=text_to_token_ids(input_text, tokenizer).to(device),
        max_new_tokens=256,
        context_size=BASE_CONFIG['context_length'],
        eos_id=50256,
    )
    generated_text = token_ids_to_text(token_ids, tokenizer)
    # Drop the echoed prompt and the response header, keeping only the answer.
    response_text = generated_text[len(input_text):].replace("### Response:", "").strip()

    print(input_text)
    print(f"\nCorrect response:\n>> {data['output']}")
    print(f"\nModel response:\n>> {response_text}")
    print("\n----------------------------------------------\n")

In [None]:
generate_model_output(policy_model, tokenizer, val_data[2])

In [None]:
def train_model_dpo_simple(
    policy_model,
    reference_model,
    train_loader,
    val_loader,
    optimizer,
    num_epochs,
    beta,
    eval_freq,
    eval_iter,
    start_context,
    tokenizer
):
    """Train the policy model with the DPO loss and track evaluation metrics.

    Every ``eval_freq`` optimizer steps (starting with the first step) the
    losses and rewards are evaluated on up to ``eval_iter`` batches of each
    loader, a summary line is printed, and a sample generation for
    ``start_context`` is printed.

    Args:
        policy_model: Model being optimized.
        reference_model: Fixed reference model.
        train_loader: Loader of training batches.
        val_loader: Loader of validation batches.
        optimizer: Optimizer over the policy model's parameters.
        num_epochs: Number of passes over ``train_loader``.
        beta: DPO scaling factor.
        eval_freq: Evaluate every this many steps.
        eval_iter: Maximum number of batches per evaluation and loader.
        start_context: Record used for the sample generation.
        tokenizer: Tokenizer used for the sample generation.

    Returns:
        Dict of lists (one entry per evaluation): 'train_losses',
        'train_chosen_rewards', 'train_rejected_rewards', 'val_losses',
        'val_chosen_rewards', 'val_rejected_rewards' and 'tokens_seen'.
    """
    tracking = {
        'train_losses': [],
        'train_chosen_rewards': [],
        'train_rejected_rewards': [],
        'val_losses': [],
        'val_chosen_rewards': [],
        'val_rejected_rewards': [],
        'tokens_seen': []
    }
    # global_step starts at -1 so that the first optimizer step is step 0.
    tokens_seen, global_step = 0, -1

    # Main training loop
    for epoch in range(num_epochs):
        policy_model.train()

        for batch in train_loader:
            optimizer.zero_grad()

            # Only the loss is needed for the update; the rewards are ignored here.
            loss, _, _ = compute_dpo_loss_batch(
                batch=batch,
                policy_model=policy_model,
                reference_model=reference_model,
                beta=beta
            )

            loss.backward()
            optimizer.step()  # update model parameters

            # Counts every element of the padded 'chosen' tensor.
            tokens_seen += batch['chosen'].numel()
            global_step += 1

            if global_step % eval_freq != 0:
                continue

            res = evaluate_dpo_loss_loader(
                policy_model=policy_model,
                reference_model=reference_model,
                train_loader=train_loader,
                val_loader=val_loader,
                beta=beta,
                eval_iter=eval_iter
            )

            tracking['train_losses'].append(res['train_loss'])
            tracking['train_chosen_rewards'].append(res['train_chosen_reward'])
            tracking['train_rejected_rewards'].append(res['train_rejected_reward'])
            tracking['val_losses'].append(res['val_loss'])
            tracking['val_chosen_rewards'].append(res['val_chosen_reward'])
            tracking['val_rejected_rewards'].append(res['val_rejected_reward'])
            tracking['tokens_seen'].append(tokens_seen)

            train_reward_margin = res['train_chosen_reward'] - res['train_rejected_reward']
            val_reward_margin = res['val_chosen_reward'] - res['val_rejected_reward']

            # Each fragment ends with an explicit separator so the fields do not
            # run together in the printed line.
            print(
                f"Ep {epoch+1} (Step {global_step:06d}): "
                f"Train loss {res['train_loss']:.3f}, "
                f"Val loss {res['val_loss']:.3f}, "
                f"Train reward margin {train_reward_margin:.3f}, "
                f"Val reward margin {val_reward_margin:.3f}"
            )

            generate_model_output(
                model=policy_model,
                tokenizer=tokenizer,
                data=start_context,
                device=loss.device,
            )

    return tracking

In [None]:
torch.manual_seed(123)

res = evaluate_dpo_loss_loader(
    policy_model=policy_model,
    reference_model=reference_model,
    train_loader=train_loader,
    val_loader=val_loader,
    beta=0.1,
    eval_iter=5
)

print('Training loss: ', res['train_loss'])
print('val loss: ', res['val_loss'])
(res['train_chosen_reward'] - res['train_rejected_reward']), (res['val_chosen_reward'] - res['val_rejected_reward'])

In [None]:
# Print the policy model's response for two validation examples.
# The loop variable is named `entry` so that the module-level `data` list is not
# overwritten, and the shared helper is used instead of repeating its body.
for entry in val_data[5:7]:
    generate_model_output(
        model=policy_model,
        tokenizer=tokenizer,
        data=entry,
        device=device,
    )

In [None]:
import time
start_time = time.time()

optimizer = torch.optim.AdamW(policy_model.parameters(), lr=5e-6, weight_decay=0.01)
num_epochs = 1
tracking = train_model_dpo_simple(
    policy_model=policy_model,
    reference_model=reference_model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    num_epochs=num_epochs,
    beta=0.1,
    eval_freq=5,
    eval_iter=5,
    start_context=val_data[2],
    tokenizer=tokenizer,
)

end_time = time.time()
execution_time_mins = (end_time - start_time)/60
print(f"{execution_time_mins:.2f}")

In [None]:
len(tracking['train_losses'])

In [None]:
from gpt_model import plot_losses
epochs_tensor = torch.linspace(0, num_epochs, len(tracking['train_losses']))
plot_losses(
    epochs_seen=epochs_tensor,
    tokens_seen=tracking['tokens_seen'],
    train_losses=tracking['train_losses'],
    val_losses=tracking['val_losses'],
    label='loss'
)

In [None]:
# Reward margin = chosen reward - rejected reward, one value per evaluation step.
train_rewards_margins = [i-j for i, j in zip(tracking['train_chosen_rewards'], tracking['train_rejected_rewards'])]
val_reward_margins = [i-j for i, j in zip(tracking['val_chosen_rewards'], tracking['val_rejected_rewards'])]

# The plotting helper's value arguments are named *_losses, but the series shown
# here are reward margins, so the label says so instead of 'loss'.
plot_losses(
    epochs_seen=epochs_tensor,
    tokens_seen=tracking['tokens_seen'],
    train_losses=train_rewards_margins,
    val_losses=val_reward_margins,
    label='reward margins'
)

In [None]:
# Compare reference-model and policy-model generations on validation examples.
# The loop variable is named `entry` so that the module-level `data` list is not
# overwritten.
for entry in val_data[:10]:
    input_text = format_input(entry)

    # Generate with the reference model first, then the policy model.
    responses = {}
    for label, current_model in (("Reference", reference_model), ("Policy", policy_model)):
        token_ids = generate(
            model=current_model,
            idx=text_to_token_ids(input_text, tokenizer).to(device),
            max_new_tokens=256,
            context_size=BASE_CONFIG['context_length'],
            eos_id=50256,
        )
        generated_text = token_ids_to_text(token_ids, tokenizer)
        # Drop the echoed prompt and the response header, keeping only the answer.
        responses[label] = generated_text[len(input_text):].replace("### Response:", "").strip()

    print(input_text)
    print(f"\nCorrect response:\n>> {entry['output']}")
    for label, response_text in responses.items():
        print(f"\n{label} Model response:\n>> {response_text}")
    print("\n----------------------------------------------\n")

In [None]:
# Compare reference-model and policy-model generations on test examples.
# The loop variable is named `entry` so that the module-level `data` list is not
# overwritten.
for entry in test_data[:5]:
    input_text = format_input(entry)

    # Generate with the reference model first, then the policy model.
    responses = {}
    for label, current_model in (("Reference", reference_model), ("Policy", policy_model)):
        token_ids = generate(
            model=current_model,
            idx=text_to_token_ids(input_text, tokenizer).to(device),
            max_new_tokens=256,
            context_size=BASE_CONFIG['context_length'],
            eos_id=50256,
        )
        generated_text = token_ids_to_text(token_ids, tokenizer)
        # Drop the echoed prompt and the response header, keeping only the answer.
        responses[label] = generated_text[len(input_text):].replace("### Response:", "").strip()

    print(input_text)
    print(f"\nCorrect response:\n>> {entry['output']}")
    for label, response_text in responses.items():
        print(f"\n{label} Model response:\n>> {response_text}")
    print("\n----------------------------------------------\n")