In [1]:
import torch

from typing import Union, Literal, Tuple
from datasets import load_dataset
from peft import LoraConfig, AutoPeftModelForCausalLM
from transformers import (
    AutoModelForCausalLM,
    AutoConfig,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from trl import (
    DPOTrainer, 
    DataCollatorForCompletionOnlyLM
)

# Seed torch's global RNG so the random token mutations drawn by `mutate`
# (torch.randperm / torch.randint) are reproducible across kernel restarts
# for a given order of dataset accesses.
SEED = 42
torch.manual_seed(SEED)

# Directory holding the PEFT adapter this notebook starts from (presumably the
# output of an earlier SFT run — confirm).
SFT_ADAPTER_DIRECTORY = "./open_llama_3b_v2_sft/"

2023-12-13 03:17:18.778257: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-12-13 03:17:18.828290: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX512F AVX512_VNNI, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Model prep

In [2]:
# Hugging Face Hub id of the base checkpoint. Below it is used to load the
# tokenizer; the models themselves are loaded from SFT_ADAPTER_DIRECTORY.
base_model = "openlm-research/open_llama_3b_v2"

In [3]:
# Tokenizer of the base LLaMA checkpoint. Padding is done on the right and
# reuses the EOS token (presumably because no dedicated pad token is defined).
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [4]:
# 4-bit NF4 quantization used when loading both the policy and reference
# models; weights are stored in 4 bits and compute runs in float16.
# (An AutoConfig with `init_device` used to be built here, but it was never
# passed to any `from_pretrained` call, so it had no effect and was removed.)
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=False,
)

In [5]:
# Policy model for DPO: the checkpoint in SFT_ADAPTER_DIRECTORY (base model +
# PEFT adapter), loaded in 4 bits with the adapter left trainable.
model = AutoPeftModelForCausalLM.from_pretrained(
    SFT_ADAPTER_DIRECTORY,
    is_trainable=True,
    trust_remote_code=True,
    quantization_config=quant_config,
)
# The KV cache only helps generation, so turn it off for training.
# pretraining_tp = 1 presumably selects LLaMA's standard (non-sliced) linear
# path — confirm against the transformers LLaMA docs for your version.
model.config.pretraining_tp = 1
model.config.use_cache = False

In [6]:
# Reference model for DPO: a second, frozen 4-bit copy of the same SFT
# checkpoint (adapter loaded with is_trainable=False).
model_ref = AutoPeftModelForCausalLM.from_pretrained(
    SFT_ADAPTER_DIRECTORY,
    is_trainable=False,
    trust_remote_code=True,
    quantization_config=quant_config,
)
# Same config tweaks as the policy model.
model_ref.config.pretraining_tp = 1
model_ref.config.use_cache = False

## Dataset prep

In [7]:
# Hub text-to-SQL dataset; raw rows have question / context / answer / db_id.
dataset_name = "samlhuillier/sql-create-context-spider-intersect"

In [8]:
# Marker that ends every prompt and precedes the SQL answer.
response_template = "\n-- Answer:\n"

In [9]:
def print_tokens_with_ids(txt, tok=None):
    """Print ``txt`` as a list of ``(token, token_id)`` pairs, without special tokens.

    Handy for checking how ``response_template`` tokenizes on its own versus
    inside a prompt.

    Args:
        txt: Text to tokenize.
        tok: Tokenizer exposing ``encode`` and ``convert_ids_to_tokens``.
            Defaults to the notebook-level ``tokenizer``.
    """
    tok = tokenizer if tok is None else tok
    # Encode once and map the ids back to token strings, so both lists are
    # aligned by construction. Tokenizing and encoding separately and then
    # zipping would silently truncate if the two ever disagreed in length.
    token_ids = tok.encode(txt, add_special_tokens=False)
    tokens = tok.convert_ids_to_tokens(token_ids)
    print(list(zip(tokens, token_ids)))

In [10]:
# The answer template tokenized on its own.
print_tokens_with_ids(txt=response_template)

[('▁', 29500), ('<0x0A>', 13), ('--', 559), ('▁Answer', 13910), (':', 29537), ('<0x0A>', 13)]


In [11]:
# The same template inside a prompt: here it tokenizes without the standalone
# '▁' (id 29500) that appears when it is tokenized alone.
print_tokens_with_ids(txt="-- Question: HI" + response_template)

[('▁--', 1472), ('▁Question', 10706), (':', 29537), ('▁HI', 27003), ('<0x0A>', 13), ('--', 559), ('▁Answer', 13910), (':', 29537), ('<0x0A>', 13)]


In [12]:
def format_prompt(example) -> Tuple[str, str]:
    """Turn one raw dataset row into a ``(prompt, completion)`` pair.

    The prompt is the row's ``context``, then ``" \\n-- Question: "`` and the
    ``question``, then ``response_template``. The completion is the gold
    ``answer``.
    """
    prompt = "{} \n-- Question: {}{}".format(
        example["context"], example["question"], response_template
    )
    return prompt, example["answer"]

In [13]:
def mutate(response, num_tokens=1, tok=None, generator=None):
    """Replace ``num_tokens`` random tokens of ``response`` with random vocabulary tokens.

    Used to synthesise a "rejected" answer from a "chosen" one for DPO.

    Args:
        response: Text to corrupt.
        num_tokens: Number of token positions to overwrite.
            - ``<= 0`` returns ``response`` unchanged.
            - Values larger than the number of eligible positions mutate every
              eligible position.
        tok: Tokenizer exposing ``encode``, ``decode`` and ``vocab_size``.
            Defaults to the notebook-level ``tokenizer``.
        generator: Optional ``torch.Generator`` for reproducible sampling.
            ``None`` uses torch's global RNG.

    Returns:
        The decoded text with the selected tokens replaced. Each replacement
        differs from the token it replaces, so a mutation is never a no-op.
        Empty input (no tokens) is returned unchanged.
    """
    tok = tokenizer if tok is None else tok
    if num_tokens <= 0:
        return response

    token_ids = tok.encode(response, add_special_tokens=False)

    # If dropping the first token leaves the decoded text unchanged, that token
    # is a bare prefix piece (e.g. the standalone '▁' SentencePiece can emit).
    # Mutating it would be wasted, so only later positions are eligible.
    start = 1 if token_ids and tok.decode(token_ids[1:]) == tok.decode(token_ids) else 0
    n_eligible = len(token_ids) - start
    if n_eligible <= 0:
        return response

    positions = torch.randperm(n_eligible, generator=generator)[:num_tokens] + start
    for pos in positions.tolist():
        # Sample uniformly from every id except the current one: draw from
        # [0, vocab_size - 1), then shift draws at or above the current id up by one.
        new_id = torch.randint(tok.vocab_size - 1, (1,), generator=generator).item()
        if new_id >= token_ids[pos]:
            new_id += 1
        token_ids[pos] = new_id

    return tok.decode(token_ids)

In [14]:
# Demo: corrupt two tokens of a sample sentence (output is random on each run).
mutate("-- Hello, this is a long sentence to demonstrate string mutation", num_tokens=2)

'-- Hello, this is a long sentence RodSignature string mutation'

In [15]:
def get_dataset(
    split: Literal["train", "validation"] = "train",
    toks_to_mutate=1,
    eval_toks_to_mutate=0,
):
    """Load a split of ``dataset_name`` as an on-the-fly DPO preference dataset.

    A ``set_transform`` is attached. Each time rows are indexed they are
    (1) formatted with ``format_prompt`` and (2) given a ``rejected`` answer,
    produced by corrupting the gold answer with ``mutate``. The transform runs
    on every access, so re-reading the same row yields a freshly sampled
    ``rejected``.

    Indexed batches have the structure::

        {
            'prompt': List[str],
            'chosen': List[str],
            'rejected': List[str],
        }

    The returned object's repr and ``column_names`` still show the raw columns.
    Only indexing goes through the transform.

    Args:
        split: Dataset split to load.
        toks_to_mutate: Tokens to corrupt per rejected answer on the ``"train"``
            split.
        eval_toks_to_mutate: Tokens to corrupt on any other split. Defaults to 0,
            which leaves ``rejected`` essentially identical to ``chosen``. That
            is fine for inspection, but such pairs carry no preference signal
            for DPO evaluation.

    Returns:
        The ``datasets.Dataset`` with the transform attached.
    """
    dataset = load_dataset(dataset_name, split=split)
    tokens_to_mutate = toks_to_mutate if split == "train" else eval_toks_to_mutate

    def batched_mutate(examples):
        # Build prompt / chosen / rejected lists for one batch of raw rows.
        out = {"prompt": [], "chosen": [], "rejected": []}
        for question, ctx, ans in zip(examples["question"], examples["context"], examples["answer"]):
            prompt, resp = format_prompt({"question": question, "context": ctx, "answer": ans})
            out["prompt"].append(prompt)
            out["chosen"].append(resp)
            out["rejected"].append(mutate(resp, num_tokens=tokens_to_mutate))
        return out

    dataset.set_transform(batched_mutate)
    return dataset

In [16]:
# Training split, with 5 random token mutations per rejected answer. The repr
# below lists the raw columns: the prompt/chosen/rejected view only appears
# when rows are indexed, because it is produced by set_transform.
ds = get_dataset(split="train", toks_to_mutate=5)
ds

Dataset({
    features: ['answer', 'question', 'context', 'db_id'],
    num_rows: 3961
})

In [17]:
# First training example: prompt, gold SQL ("chosen") and a mutated copy ("rejected").
ds[0]

{'prompt': 'CREATE TABLE head (age INTEGER) \n-- Question: How many heads of the departments are older than 56 ?\n-- Answer:\n',
 'chosen': 'SELECT count(*) FROM head WHERE age  >  56',
 'rejected': 'SELECT count offset Ker Dolத WHERE age  > experience56'}

In [18]:
# Same row again: indexing re-runs the transform, so "rejected" gets freshly sampled mutations.
ds[0]

{'prompt': 'CREATE TABLE head (age INTEGER) \n-- Question: How many heads of the departments are older than 56 ?\n-- Answer:\n',
 'chosen': 'SELECT count(*) FROM head WHERE age  >  56',
 'rejected': 'SELECT count(*世 FROMicip AP ageading>  5FAULT'}

## Trainer Prep

In [19]:
# DPO trainer.
# - Policy: `model` (SFT checkpoint, adapter trainable).
# - Reference: `model_ref` (same SFT checkpoint, adapter frozen); the DPO loss
#   is measured against it.
#
# No `peft_config` is passed. `model` is already a PEFT model whose SFT adapter
# was loaded with is_trainable=True, so DPO should keep training that adapter.
# Given both a PEFT model and a new LoraConfig, trl's DPOTrainer merges the
# existing adapter into the base weights and attaches a fresh, untrained adapter.
# The adapter saved at the end would then no longer contain the SFT weights.
trainer = DPOTrainer(
    model,
    model_ref,
    tokenizer=tokenizer,
    train_dataset=ds,
    beta=0.2,
    max_length=2048,
    max_prompt_length=1500,
    args=TrainingArguments(
        output_dir="./dpo_results",
        optim="paged_adamw_32bit",

        max_grad_norm=0.3,
        warmup_ratio=0.03,

        learning_rate=2e-4,
        weight_decay=0.001,
        num_train_epochs=1,
        max_steps=-1,
        per_device_train_batch_size=2,

        gradient_accumulation_steps=1,
        save_steps=500,
        logging_steps=100,
        logging_first_step=True,

        fp16=False,
        bf16=False,

        # Keep every column: the DPO collator needs the prompt/chosen/rejected
        # fields, which come from the dataset's transform, not its stored columns.
        remove_unused_columns=False,
        # Plain "constant" ignores warmup_ratio entirely. The warmup variant makes
        # the 3% warmup configured above actually take effect.
        lr_scheduler_type="constant_with_warmup",
        report_to="tensorboard",
    ),
)



## Train

In [None]:
# Run DPO training. Loss is logged every `logging_steps` steps (and on the
# first step) and reported to TensorBoard, per the TrainingArguments above.
trainer.train()

Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss
1,0.6193


In [None]:
# Save the DPO-trained model. `trainer.model` is a PEFT model, so this writes
# the adapter (weights + adapter config), not a full copy of the base model.
# The tokenizer is saved alongside so its customised pad token and padding side
# travel with the checkpoint. The model save comes last so the cell shows no
# return-value repr.
new_model = "open_llama_3b_v2_sft_plus_dpo"
tokenizer.save_pretrained(new_model)
trainer.model.save_pretrained(new_model)