In [None]:
# Download files

import requests, zipfile, io

# Set download_files = True to fetch the archive at files_url and unpack it into the
# current directory (uses the requests/zipfile/io imports above). With the default
# False nothing is downloaded, and the files are expected to already be on disk.
files_url = "https://ideami.com/llm_align"
download_files = False
if download_files:
    print("Downloading files using Python")
    response = requests.get(files_url)
    # Fail with a clear HTTP error instead of a confusing BadZipFile on an error page.
    response.raise_for_status()
    zipfile.ZipFile(io.BytesIO(response.content)).extractall(".")

In [None]:
# Import libraries
import os, sys
import math
from tqdm import tqdm
from datetime import datetime
import ipdb
from typing import List, Dict, Union

# PyTorch
import torch
import torch.nn as nn
from torch.nn import functional as F

# Import some HuggingFace Libraries
import transformers
from datasets import load_dataset, load_from_disk

# Performance (if you have cuda): uncomment to let matmuls / cuDNN convolutions use TF32
# torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cudnn.allow_tf32 = True

# torch.cuda.empty_cache()

# Optional, for debugging: print tensors with up to 10k elements in full instead of summarising them
torch.set_printoptions(threshold=10_000)

In [None]:
# Training parameters
batch_size = 1
epochs = 3  # after 3 epochs can possibly degrade or get worse
lr = 6e-5
lr_warmup_steps = 100  # the lr grows linearly from 0 during the first 100 scheduler steps
context = 1024  # every tokenized sequence is padded/truncated to this many tokens
alpha = 0.5  # weight of the ORPO odds-ratio term relative to the NLL (SFT) term
prompt_max_size = 512  # rows whose prompt has this many tokens or more are filtered out
# prompt: includes all the interaction except the last answer
# response: includes either the positive chosen answer or the negative rejected one
compile = False  # torch.compile the model (note: this name shadows the built-in `compile`)
dtype = torch.bfloat16
log_iters = 50  # print / log every `log_iters` steps

# HYPERPARAMETERS
dropout = 0.  # NOTE(review): not used below; the model takes its dropout from the checkpoint config
grad_clip = 1.0
weight_decay = 0.0

# DEVICE: prefer CUDA, then Apple-silicon MPS (macOS), and fall back to the CPU otherwise.
# (Previously any machine without CUDA got "mps", which fails where MPS does not exist.)
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print("device: You will be using: ", device)

In [None]:
# LOGGING (Weights & Biases)
project_name = "alignment"
wandb_log = True
wandb_project = project_name
# A timestamp suffix makes every run name unique (recommended); the fixed name is kept for reference.
# wandb_run_name = "aligntest-run"
wandb_run_name = f"aligntest-run{datetime.now():%Y_%m_%d_%H_%M_%S}"

if wandb_log:
    # Imported lazily so the notebook does not need wandb when logging is disabled.
    import wandb
    wandb.init(project=wandb_project, name=wandb_run_name)

In [None]:
dataset_path = "./data/orpo_dataset6"  # on-disk cache of the filtered + tokenized dataset
dataset_name = "mlabonne/orpo-dpo-mix-40k"  # Hub dataset whose rows hold 'chosen'/'rejected' conversations
tokenizer_path = "tokenizers/tok16384"
checkpoint_dir = "./models/"  # holds base_model.pt and receives the per-epoch alignment checkpoints

# Tokenizing Dataset
# Load tokenizer in HuggingFace Format
tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_path)

# Set our interaction template: each user/assistant turn is rendered as "<|user|>..." or
# "<|assistant|>..." followed by the EOS token. With add_generation_prompt=True a trailing
# "<|assistant|>" tag is added after the last message to cue the model to answer.
# The newlines inside the triple-quoted string are part of the template. Editing its
# whitespace may change the tokenized data, and a dataset already cached at dataset_path
# would then be stale.
# Earlier single-line variant, kept for reference. NOTE(review): not valid as written — the
# quote opened before <|user|> is never closed before "+ message[...]".
# tokenizer.chat_template = "{% for message in messages %}{% if message['role']=='user' %}\n{{ '<|user|> + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{%endif %}\n{% endfor %}"
tokenizer.chat_template = """{% for message in messages %}
{% if message['role'] == 'user' %}
<|user|>{{ message['content'] }}{{ eos_token }}
{% elif message['role'] == 'assistant' %}
<|assistant|>{{ message['content'] }}{{ eos_token }}
{% endif %}
{% if loop.last and add_generation_prompt %}
<|assistant|>
{% endif %}
{% endfor %}"""

# Make padding token equal to the end of sentence token (which has ID of 2 in our case).
# Consequence: padding cannot be told apart from a real EOS by token id alone; the
# attention masks produced by the tokenizer are what distinguish the two.
tokenizer.pad_token = tokenizer.eos_token

if os.path.exists(dataset_path):
    # Reuse the tokenized dataset saved by a previous run. NOTE: this cache is never
    # invalidated automatically. Delete `dataset_path` after changing the tokenizer, the
    # chat template, `context` or `prompt_max_size`.
    dataset = load_from_disk(dataset_path)
else:
    print("Filtering and tokenizing dataset")
    dataset = load_dataset(dataset_name, split="all")

    # Optional: drop the toxic-DPO subset (original note: 37136 -> 36622 rows)
    dataset = dataset.filter(lambda r: r["source"] != "toxic-dpo-v0.2")

    # FILTER DATASET
    def filter_dataset(examples):
        """Keep a row only if its prompt is shorter than `prompt_max_size` tokens.

        Called once per row by `dataset.filter`. The prompt is the 'chosen' conversation
        without its last message, rendered with the chat template plus the trailing
        assistant tag. Bounding it leaves room for the answer within the `context` window.
        """
        prompt_length = tokenizer.apply_chat_template(
            examples['chosen'][:-1], tokenize=True, add_generation_prompt=True, return_tensors='pt'
        ).size(-1)
        return prompt_length < prompt_max_size

    # Preprocess and tokenize data
    def preprocess_dataset(examples: Dict[str, list]):
        """Batched `map` function: render and tokenize prompt, chosen and rejected texts.

        `examples` maps column names to lists of values (one entry per row in the batch).
        Returns the tokenizer output for the prompts (`input_ids`, `attention_mask`),
        extended with `positive_input_ids`/`positive_attention_mask` (chosen conversation)
        and `negative_input_ids`/`negative_attention_mask` (rejected conversation).
        Every sequence is padded/truncated to `context` tokens.
        """
        # Prompt: the chosen conversation minus its last answer, plus the assistant tag.
        prompt = [tokenizer.apply_chat_template(item[:-1], tokenize=False, add_generation_prompt=True) for item in examples['chosen']]
        chosen = [tokenizer.apply_chat_template(item, tokenize=False) for item in examples['chosen']]
        rejected = [tokenizer.apply_chat_template(item, tokenize=False) for item in examples['rejected']]

        # Tokenize. The prompt is tokenized on its own. The later masking by prompt_attention_mask
        # assumes its tokens form a prefix of the chosen/rejected tokens.
        # TODO(review): confirm tokenization is prefix-stable at the prompt/answer boundary.
        inputs = tokenizer(prompt, max_length=context, padding='max_length', truncation=True, return_tensors='pt')
        pos_labels = tokenizer(chosen, max_length=context, padding='max_length', truncation=True, return_tensors='pt')
        neg_labels = tokenizer(rejected, max_length=context, padding='max_length', truncation=True, return_tensors='pt')

        inputs['positive_input_ids'] = pos_labels['input_ids']
        inputs['positive_attention_mask'] = pos_labels['attention_mask']

        inputs['negative_input_ids'] = neg_labels['input_ids']
        inputs['negative_attention_mask'] = neg_labels['attention_mask']

        return inputs

    # Exclude prompts that are too long
    dataset = dataset.filter(filter_dataset)

    # Preprocess and tokenize dataset. Batched map sends 1000 rows per call by default.
    # If you have issues with multiprocessing keep num_proc=1; otherwise e.g. num_proc=min(32, os.cpu_count()).
    dataset = dataset.map(preprocess_dataset, batched=True, num_proc=1, remove_columns=dataset.column_names)

    dataset.save_to_disk(dataset_path)

In [None]:
# Decode the first row's chosen conversation (rendered with the chat template, then padding)
# back to text, to eyeball the template and the tokenization.
tokenizer.decode(dataset[0]["positive_input_ids"])

In [None]:
from datasets import Dataset, DatasetDict

# Split the data: 95% train / 5% validation. Both the shuffle and the split are seeded,
# so the partition is identical across kernel restarts.
if isinstance(dataset, Dataset):
    dataset = dataset.shuffle(42).train_test_split(test_size=0.05, seed=42)
elif isinstance(dataset, DatasetDict):
    print("It is already a DatasetDict, skipping split.")
train_data = dataset['train']
val_data = dataset['test']
# Features of both splits: input_ids / attention_mask (prompt),
# positive_input_ids / positive_attention_mask (chosen),
# negative_input_ids / negative_attention_mask (rejected).

# Stacks the already max_length-padded features into tensors. With mlm=False it also adds a
# `labels` key, which the training loop does not use (it builds its own labels).
data_collator = transformers.DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Setup DataLoaders (the data was shuffled once above, so the loaders keep a fixed order).
# NOTE(review): val_loader is not used by the training loop in this notebook.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, collate_fn=data_collator, num_workers=0)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False, collate_fn=data_collator, num_workers=0)

In [None]:
# Sanity check: pull one collated batch from the training loader to inspect its keys and shapes.
# The training loop's `for` statement rebinds `batch`, so this peek does not affect training.
it = iter(train_loader)
batch = next(it)
# Uncomment to decode the first chosen sequence of the batch (conversation + padding):
#print(tokenizer.decode(batch['positive_input_ids'][0]))

In [None]:
# SETUP ARCHITECTURE

from llm import Llama, ModelArgs

# The checkpoint is a dict holding the model state_dict plus a "config" object (popped below).
# weights_only=False is presumably required to unpickle that config object. Unpickling can
# execute arbitrary code, so only load checkpoints you trust.
# map_location="cpu" loads every tensor on the CPU first, so a checkpoint saved from a CUDA
# run also loads on MPS/CPU machines. The model is moved to `device` below.
checkpoint = torch.load(os.path.join(checkpoint_dir, "base_model.pt"), map_location="cpu", weights_only=False)
config = checkpoint.pop("config")

# NOTE: dropout comes from the checkpoint config; the notebook-level `dropout` is not used here.
model_args = ModelArgs(
    dim=config.hidden_size,
    n_layers=config.num_hidden_layers,
    n_heads=config.num_attention_heads,
    n_kv_heads=config.num_key_value_heads,
    vocab_size=config.vocab_size,
    norm_eps=config.rms_norm_eps,
    rope_theta=config.rope_theta,
    max_seq_len=context,
    dropout=config.attention_dropout,
    hidden_dim=config.intermediate_size,
    attention_bias=config.attention_bias,
    mlp_bias=config.mlp_bias
)
# dim=768, n_layers=12, n_heads=12, vocab=16384, etc

model = Llama(model_args)
model.load_state_dict(checkpoint)
# Cast floating-point parameters/buffers to the training dtype and move them in one call.
model = model.to(device=device, dtype=dtype)
model.train()

if compile:
    print("[INFO] Compiling model")
    model = torch.compile(model)

print(sum(p.numel() for p in model.parameters()) / 1e6, " Million parameters")

In [None]:
# Optimizer
# The fused AdamW implementation is used on the GPU backends (CUDA and Apple MPS). On the CPU
# fallback the default implementation is used.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr,
    betas=(0.9, 0.95),
    eps=1e-8,
    fused=device in ("cuda", "mps"),
    weight_decay=weight_decay,
)

# One optimizer (and scheduler) step per batch.
num_training_steps = len(train_loader) * epochs
print(f"num_training_steps: {num_training_steps}")

# Scheduler for lr: during the first `lr_warmup_steps` steps the lr grows linearly from 0.
# After the warmup it decays to 0 following half a cosine period.

def lr_lambda(current_step, warmup_steps=None, total_steps=None):
    """Multiplicative learning-rate factor for LambdaLR at scheduler step `current_step`.

    Args:
        current_step: number of scheduler steps taken so far.
        warmup_steps: length of the linear warm-up; defaults to the global `lr_warmup_steps`.
        total_steps: step at which the factor reaches 0; defaults to the global
            `num_training_steps`.

    Returns:
        A float in [0, 1]: current_step / warmup_steps during warm-up, then
        0.5 * (1 + cos(pi * progress)) with progress running from 0 to 1.
    """
    if warmup_steps is None:
        warmup_steps = lr_warmup_steps
    if total_steps is None:
        total_steps = num_training_steps

    if current_step < warmup_steps:
        return float(current_step) / float(max(1, warmup_steps))
    progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
    # Clamp so steps past total_steps keep the factor at 0 instead of climbing back up the cosine.
    progress = min(progress, 1.0)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

# LambdaLR sets the lr to lr * lr_lambda(step). It applies lr_lambda(0) == 0 on construction,
# then re-evaluates after every scheduler.step().
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

In [None]:
def compute_logps(prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits):
    """Average log-probability per answer token for each sequence in the batch.

    Args:
        prompt_attention_mask: 2-D 0/1 mask (batch, seq), 1 on the prompt tokens.
        chosen_inputs: 2-D token ids (batch, seq) of the full sequence the logits were computed for.
        chosen_attention_mask: 2-D 0/1 mask (batch, seq), 1 on the non-padding tokens.
        logits: 3-D model outputs (batch, seq, vocab) for `chosen_inputs`.

    Returns:
        float32 tensor of shape (batch,): the selected per-token log-probs summed and divided
        by the number of selected positions. The result is NaN if a row selects no position.

    Row t of `logits` predicts token t+1. Assuming right-padded masks (1s then 0s), with
    P = prompt length and C = sequence length, the shifted mask selects t = P-1 .. C-1. That
    covers every answer token plus the token right after the sequence (a pad), when C < seq.
    """
    # mask[t] == 1 iff token t is real and token t+1 is not part of the prompt.
    mask = chosen_attention_mask[:, :-1] - prompt_attention_mask[:, 1:]
    # Target of row t is token t+1. Unselected rows gather id 0 (a valid dummy index); the
    # mask zeroes them out below.
    index = (mask * chosen_inputs[:, 1:]).unsqueeze(2)
    # log_softmax and the reduction run in float32. bfloat16 keeps only ~2-3 significant
    # digits, which would quantize the small chosen-vs-rejected differences the odds ratio uses.
    log_probs = logits[:, :-1, :].float().log_softmax(-1)
    per_token_logps = torch.gather(log_probs, dim=2, index=index).squeeze(2)

    mask = mask.to(per_token_logps.dtype)
    return (per_token_logps * mask).sum(dim=1) / mask.sum(dim=1)

In [None]:
# TOY EXAMPLE TO UNDERSTAND THE COMPUTATION
# Walks through compute_logps for one unbatched sequence of 11 tokens and a 5-entry
# vocabulary, so every tensor has one dimension less than in the real function.
p_mask = torch.tensor([1,1,1,1,0,0,0,0,0,0,0]) # prompt attention mask: the prompt is the first 4 tokens
c_mask = torch.tensor([1,1,1,1,1,1,1,0,0,0,0]) # chosen attention mask: 7 real tokens, so the answer is three 1s
c_inputs = torch.tensor([2,4,3,2,4,2,1,0,0,0,0]) # chosen token ids; 0s where c_mask is 0 (padding)
print(p_mask)
print(c_mask)
print(c_inputs)

# Shifted mask: mask[t] == 1 when c_mask[t] == 1 and p_mask[t+1] == 0, i.e. t = 3..6.
# Row t of the logits predicts token t+1, so this selects the 3 answer tokens
# (positions 4-6) plus the first padding token after them (position 7): four 1s.
mask = c_mask[:-1] - p_mask[1:]
print(mask)

# Stand-in "logits" for 11 positions x 5 vocabulary entries. NOTE: each row sums to ~1 and no
# log_softmax is applied below, so the gathered numbers are probabilities here, whereas
# compute_logps gathers log-probabilities.
logits = torch.tensor([
    [0.1583, 0.0794, 0.1967, 0.3643, 0.2013],
    [0.2517, 0.0463, 0.2702, 0.1700, 0.2617],
    [0.1405, 0.1943, 0.1371, 0.1557, 0.3724],
    [0.1266, 0.2257, 0.2330, 0.1872, 0.2275],
    [0.1685, 0.2091, 0.1680, 0.1649, 0.2895],
    [0.1993, 0.1869, 0.2283, 0.2176, 0.1679],
    [0.1359, 0.2634, 0.1817, 0.1952, 0.2237],
    [0.2431, 0.1394, 0.1615, 0.2876, 0.1684],
    [0.2019, 0.2145, 0.2046, 0.1186, 0.2604],
    [0.2340, 0.1782, 0.2505, 0.1342, 0.2031],
    [0.1667, 0.2784, 0.1309, 0.1384, 0.2856],
])

print(c_inputs[1:]) # next-token targets: element t is the token that logits row t should predict
index=(mask * c_inputs[1:]) # target id where mask == 1, otherwise 0 (a valid dummy index, zeroed later)
print(index)
print(index.shape)
print(logits.shape)

# Expand dimensions for correct gather shape: gather needs the index to have the same rank as the input
index_expanded = index.unsqueeze(1)
print(index_expanded)
print(index_expanded.shape)

# Gather the values at the specified indices: row t of the result is logits[t, index[t]]
gathered_values = torch.gather(logits[:-1,:], dim=1, index=index_expanded) # drop the last row: it would predict a token past the end of the sequence
print(gathered_values)

# Squeeze to remove the unnecessary dimension
per_token_logps = gathered_values.squeeze(1)
print(per_token_logps)

# Zero the unselected positions, then average over the selected ones (f2 == 4 here)
result = torch.mul(per_token_logps, mask)
print(result)
f1 = result.sum(dim=0)
f2 = mask.sum(dim=0)
print(f1)
print(f2)
final = f1/f2
print(final)
# Batched form used in compute_logps:
# torch.gather(logits[:, :-1, :].log_softmax(-1), dim=2, index=(mask * chosen_inputs[:, 1:]).unsqueeze(2))
                      


In [None]:
# ALIGNMENT TRAINING LOOP
# Per batch, the ORPO objective is
#   loss     = NLL(chosen answer) - alpha * mean(log sigmoid(log_odds))
#   log_odds = log(p_w / (1 - p_w)) - log(p_l / (1 - p_l))
# where log p_w / log p_l are the average per-token log-probabilities of the chosen
# (positive) and rejected (negative) answers returned by compute_logps.

# Batch tensors consumed below; other collator outputs are left where they are.
device_keys = ("positive_input_ids", "positive_attention_mask",
               "negative_input_ids", "negative_attention_mask", "attention_mask")

try:
    for e in range(epochs):
        for i, batch in tqdm(enumerate(train_loader), total=len(train_loader), dynamic_ncols=True):
            optimizer.zero_grad(set_to_none=True)

            for key in device_keys:
                batch[key] = batch[key].to(device)

            # LABELS FOR THE NLL TERM (-100 = ignored)
            # Prompt positions are ignored, and so is padding. Padding is located through the
            # attention masks, not by token id. pad_token == eos_token, so masking by id would
            # also remove the EOS the chat template appends to every message, and the NLL term
            # would never learn where an answer ends.
            prompt_mask = batch["attention_mask"] * batch["positive_attention_mask"]
            pos_labels = batch["positive_input_ids"].masked_fill(prompt_mask.bool(), -100)
            pos_labels = pos_labels.masked_fill(batch["positive_attention_mask"] == 0, -100)
            neg_labels = batch["negative_input_ids"].masked_fill(batch["negative_attention_mask"] == 0, -100)

            outputs_pos, loss_pos = model(batch["positive_input_ids"], pos_labels)
            # Only the logits of the rejected pass are needed; loss_neg is not part of the objective.
            outputs_neg, loss_neg = model(batch["negative_input_ids"], neg_labels)

            # Average per-token log-probability of each answer. Upcast to float32 so the
            # exp/log1p below are not evaluated in bfloat16.
            pos_prob = compute_logps(
                prompt_attention_mask=batch['attention_mask'],
                chosen_inputs=batch["positive_input_ids"],
                chosen_attention_mask=batch['positive_attention_mask'],
                logits=outputs_pos,
            ).float()
            neg_prob = compute_logps(
                prompt_attention_mask=batch['attention_mask'],
                chosen_inputs=batch["negative_input_ids"],
                chosen_attention_mask=batch['negative_attention_mask'],
                logits=outputs_neg,
            ).float()

            # ORPO log odds ratio. The rejected answer's log(1 - p_l) term must be added back:
            #   log odds(w) - log odds(l) = (log p_w - log p_l) - (log(1 - p_w) - log(1 - p_l))
            # log1p(-exp(x)) computes log(1 - exp(x)) accurately when exp(x) is small.
            log_odds = (pos_prob - neg_prob) - (torch.log1p(-torch.exp(pos_prob)) - torch.log1p(-torch.exp(neg_prob)))
            # logsigmoid is the numerically stable form of log(sigmoid(x)).
            ratio = F.logsigmoid(log_odds)

            # Final loss: NLL of the chosen answer + alpha * (-log sigmoid(log_odds))
            loss = torch.mean(loss_pos - (alpha * ratio).mean()).to(dtype=dtype)

            # Logging
            if i % log_iters == 0:
                print(f"Epochs [{e}/{epochs}] Step: [{i}/{len(train_loader)}], train loss: {loss.item():.3f}, Odds Ratio: {log_odds.mean().item():.3f}")

                if wandb_log:
                    wandb.log({
                        "train_loss": loss.item(),
                        "log_odds": log_odds.mean().item(),
                        "lr": scheduler.get_last_lr()[0],
                    },
                    step=(e * len(train_loader) + i))

            # Stop on a NaN loss at every step, not only on logging steps, before
            # backward()/step() can propagate NaNs into the weights.
            if torch.isnan(loss):
                print(f"NaN loss at epoch {e}, step {i}; stopping.")
                if wandb_log:
                    wandb.finish()
                sys.exit()

            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
            optimizer.step()
            scheduler.step()

        # End of epoch: save the weights together with the config. torch.compile wraps the model,
        # which prefixes its state_dict keys with "_orig_mod.". Saving the underlying module
        # keeps the checkpoint loadable into a plain Llama either way.
        raw_model = getattr(model, "_orig_mod", model)
        sd = raw_model.state_dict()
        sd['config'] = config
        torch.save(sd, os.path.join(checkpoint_dir, f'{project_name}_{e+1}.pt'))

except KeyboardInterrupt:
    print("Training interrupted. Cleaning up..")

finally:
    # Return the blocks cached by PyTorch's CUDA allocator (skipped when CUDA is absent).
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print("GPU memory released")

if wandb_log:
    wandb.finish()