In [1]:
from transformers import AutoTokenizer, TrainingArguments, AutoModelForCausalLM, BitsAndBytesConfig, AutoConfig
from datasets import load_dataset, concatenate_datasets, Dataset
from trl import DPOTrainer
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
import torch
from dataclasses import dataclass, field
from typing import Dict, List, Optional
import yaml
from utils import get_logger
from accelerate import Accelerator
import bitsandbytes as bnb
import os
import random


  from .autonotebook import tqdm as notebook_tqdm


[2024-01-14 04:23:18,007] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [6]:
logger = get_logger("finetune", "info")

@dataclass
class DatasetConfig:
    path: str  # Path to the dataset
    split: str  # Dataset split (e.g., 'train', 'test')
    type: dict  # Additional configuration for the dataset
    name: str


@dataclass
class Config:
    # Training and validation file paths
    val_data_size: float  # Validation data size ratio

    # Model configuration
    model_name: str  # Name of the model
    model_dtype: Optional[
        str
    ] = None  # model datatype, only float16 or bfloat16 supported
    token: Optional[str] = None  # Authentication token, if required
    split_model: bool = False  # Whether to split the model

    # Model training parameters
    block_size: int = 128  # Size of the blocks used in the model
    lora_rank: int = 64  # LoRA rank
    lora_alpha: Optional[int] = None  # Alpha value for LoRA
    lora_dropout: float = 0.1  # Dropout rate for LoRA
    learning_rate: float = 1e-4  # Learning rate
    lr_scheduler_type: str = "constant"  # Type of learning rate scheduler
    warmup_steps: int = 10  # Number of warmup steps
    weight_decay: float = 0.05  # Weight decay factor
    output_dir: str = "./checkpoints"  # Directory to save model checkpoints
    log_steps: int = 10  # Frequency of logging steps
    eval_steps: int = 10  # Evaluation step frequency
    save_steps: int = 10  # Model saving step frequency
    epochs: float = 1  # Number of training epochs
    batch_size: int = 1  # Training batch size
    gradient_accumulation_steps: int = 1  # Gradient accumulation steps
    gradient_checkpointing: bool = False  # Enable gradient checkpointing
    trust_remote_code: bool = False  # Trust remote code flag
    save_limit: int = 1  # Limit for saving models
    optimizer: str = "adamw_torch"
    bf16: bool = False
    fp16: bool = False

    # SFTTrainer configuration
    packing: bool = False

    # Additional model configuration
    use_int4: bool = False  # Use int4 precision
    use_int8: bool = False  # Use int8 precision
    disable_lora: bool = False  # Disable LoRA
    disable_flash_attention: bool = False  # Disable flash attention
    all_linear: bool = False  # Use LoRA on all linear layers
    pad_token_id: Optional[int] = None  # End of sequence token ID
    add_eos_token: bool = False  # Add EOS token to tokenizer
    add_bos_token: bool = False  # Add BOS token to tokenizer
    add_pad_token: bool = False  # Add PAD token to tokenizer
    padding_side: Optional[str] = None  # Padding side for tokenizer
    # New field for special tokens
    special_tokens: Dict[str, str] = field(default_factory=lambda: {})
    custom_tokens: List[str] = field(default_factory=list)  # List of custom_tokens

    # Dataset handling
    completion_only: bool = False  # Only use completion loss
    wand_db_project: str = "trl_finetuning"  # Wandb project to use
    prepare_data_path: Optional[str] = None  # dataset cache folder
    datasets: List[DatasetConfig] = field(
        default_factory=list
    )  # List of dataset configurations
    chat_template: Optional[str] = None

In [7]:
def chatml_format(example):
    # Format system
    if len(example['system']) > 0:
        message = {"role": "system", "content": example['system']}
        system = tokenizer.apply_chat_template([message], tokenize=False)
    else:
        system = ""

    # Format instruction
    message = {"role": "user", "content": example['question']}
    prompt = tokenizer.apply_chat_template([message], tokenize=False, add_generation_prompt=True)

    # Format chosen answer
    chosen = example['chosen'] + "<|im_end|>\n"

    # Format rejected answer
    rejected = example['rejected'] + "<|im_end|>\n"

    return {
        "prompt": system + prompt,
        "chosen": chosen,
        "rejected": rejected,
    }

def loaddata(config):
    if config.prepare_data_path and os.path.exists(config.prepare_data_path):
        logger.info("load datasets from disk")
        combined_dataset = Dataset.load_from_disk(config.prepare_data_path)
    else:
        logger.info("load datasets from hub")
        all_datasets = []
        for dataset_config in config.datasets:
            # Load dataset
            name = dataset_config["name"] if "name" in dataset_config else None
            path = dataset_config["path"]
            dataset = load_dataset(path, split=dataset_config["split"], name=name)
            
            # Save columns
            original_columns = dataset.column_names
            
            # Format dataset
            dataset = dataset.map(
                chatml_format,
                remove_columns=original_columns
            )
            print(dataset[0])
            all_datasets.append(dataset)
        combined_dataset = concatenate_datasets(all_datasets)

        logger.info("shuffle merged datasets")
        combined_dataset = combined_dataset.shuffle()
        if config.prepare_data_path:
            # Save combined dataset to disk
            combined_dataset.save_to_disk(config.prepare_data_path)
        
    # Split data
    split_dataset = combined_dataset.train_test_split(
        test_size=config.val_data_size,
        shuffle=True,
    )

    return split_dataset["train"], split_dataset["test"]

In [8]:
def load_config(config_file):
    with open(config_file, "r") as file:
        config_dict = yaml.safe_load(file)
    config_dict = {k: v for k, v in config_dict.items() if v is not None}
    return Config(**config_dict)

In [9]:
def prepare_tokenizer(config):
    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    if config.pad_token_id is not None:
        logger.info("Using pad token id %d", args.pad_token_id)
        tokenizer.pad_token_id = args.pad_token_id
        tokenizer.pad_token = tokenizer.convert_ids_to_tokens(args.pad_token_id)
    if config.padding_side is not None:
        tokenizer.padding_side = config.padding_side
    return tokenizer

In [10]:
SUPPORTED_FLASH_MODELS = ["llama", "mistral", "falcon", "mixtral", "opt"]

def get_model_config(args: Config):
    config_kwargs = {
        "trust_remote_code": True if args.trust_remote_code else None,
    }
    config = AutoConfig.from_pretrained(args.model_name, **config_kwargs)

    config.use_cache = False
    if not args.gradient_checkpointing:
        logger.info("Not using gradient checkpointing")
        config.gradient_checkpointing = False
    else:
        logger.info("Using gradient checkpointing")
        config.gradient_checkpointing = True

    return config

def find_all_linear_names(args, model, add_lm_head=True):
    cls = (
        bnb.nn.Linear4bit
        if args.use_int4
        else (bnb.nn.Linear8bitLt if args.use_int8 else torch.nn.Linear)
    )
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split(".")
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if add_lm_head and not "lm_head" in lora_module_names:
        logger.info("Adding lm_head to lora_module_names")
        lora_module_names.add("lm_head")

    return list(lora_module_names)

def prepare_model(args: Config):
    config = get_model_config(args)
    config_dict = config.to_dict()
    model_type = config_dict["model_type"]

    use_flash_attention = False

    if not args.disable_flash_attention and model_type not in SUPPORTED_FLASH_MODELS:
        logger.info(
            "Model is not llama, mistral, or falcon disabling flash attention..."
        )
    elif args.disable_flash_attention and model_type in SUPPORTED_FLASH_MODELS:
        logger.info(
            "Model is llama, mistral or falcon could be using flash attention..."
        )
    elif not args.disable_flash_attention:
        logger.info("Using flash attention...")
        use_flash_attention = True

    if args.split_model:
        logger.info("Splitting the model across all available devices...")
        kwargs = {"device_map": "auto"}
    else:
        kwargs = {"device_map": None}
    
    torch_dtype = torch.float32
    if args.model_dtype == "float16":
        torch_dtype = torch.float16
    elif args.model_dtype == "bfloat16":
        torch_dtype = torch.bfloat16

    if args.use_int4:
        logger.info("Using int4 quantization")
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch_dtype,
            bnb_4bit_use_double_quant=True,
        )
        if not args.split_model:
            device_index = Accelerator().process_index
            device_map = {"": device_index}
            kwargs["device_map"] = device_map
        args.use_int8 = False
    elif args.use_int8:
        logger.info("Using int8 quantization")
        bnb_config = BitsAndBytesConfig(
            load_in_8bit=True,
        )
    else:
        logger.info("Using no quantization")
        bnb_config = None
    
    # Model to fine-tune
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        quantization_config=bnb_config,
        trust_remote_code=args.trust_remote_code,
        torch_dtype=torch_dtype,
        config=config,
        use_flash_attention_2=use_flash_attention,
        **kwargs,
    )
    
    # Reference model
    ref_model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        quantization_config=bnb_config,
        trust_remote_code=args.trust_remote_code,
        torch_dtype=torch_dtype,
        config=config,
        use_flash_attention_2=use_flash_attention,
        **kwargs,
    )
    return model, ref_model

In [11]:
args = load_config("configs/mistral-dpo.yml")

In [9]:
tokenizer = prepare_tokenizer(args)
tokenizer.chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
tokenizer

LlamaTokenizerFast(name_or_path='cognitivecomputations/dolphin-2.6-mistral-7b', vocab_size=32000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '<|im_end|>', 'unk_token': '<unk>', 'pad_token': '<|im_end|>'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32000: AddedToken("<|im_start|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),
}

In [10]:
os.environ["WANDB_PROJECT"] = args.wand_db_project
train_dataset, _ = loaddata(args)

Map: 100%|██████████| 12859/12859 [00:01<00:00, 7766.51 examples/s]

{'chosen': '[\n  ["AFC Ajax (amateurs)", "has ground", "Sportpark De Toekomst"],\n  ["Ajax Youth Academy", "plays at", "Sportpark De Toekomst"]\n]<|im_end|>\n', 'rejected': " Sure, I'd be happy to help! Here are the RDF triplets for the input sentence:\n\n[AFC Ajax (amateurs), hasGround, Sportpark De Toekomst]\n[Ajax Youth Academy, playsAt, Sportpark De Toekomst]\n\nExplanation:\n\n* AFC Ajax (amateurs) is the subject of the first triplet, and hasGround is the predicate that describes the relationship between AFC Ajax (amateurs) and Sportpark De Toekomst.\n* Ajax Youth Academy is the subject of the second triplet, and playsAt is the predicate that describes the relationship between Ajax Youth Academy and Sportpark De Toekomst.\n\nNote that there may be other possible RDF triplets that could be derived from the input sentence, but the above triplets capture the main relationships present in the sentence.<|im_end|>\n", 'prompt': "<|im_start|>user\nYou will be given a definition of a task




In [11]:
model, ref_model = prepare_model(args)

The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.
Loading checkpoint shards: 100%|██████████| 3/3 [00:57<00:00, 19.10s/it]
Loading checkpoint shards: 100%|██████████| 3/3 [00:07<00:00,  2.55s/it]


In [8]:
train_dataset[2]

{'chosen': 'Based on the given review, I would rate it as a 4 on the scale of 1-5. The reviewer expresses a positive sentiment by saying "I love it," which indicates a strong liking for the app. The quote "Where\'d all my free space go?" seems to be a rhetorical question highlighting that the app effectively shows the user where their free space has gone. The overall tone seems positive, but it lacks a deeper analysis of the app\'s features and functionality. A rating of 5 would typically be reserved for a more comprehensive and detailed review that gives users a better understanding of the app.<|im_end|>\n',
 'rejected': ' Sure, I\'d be happy to help! Based on the review provided, I would rate it as a 4 out of 5 in terms of favorability. Here\'s my breakdown of the review:\n\nPros:\n\n1. The user "loves" the app, which indicates a strong positive sentiment.\n2. The app provides an answer to a question that is important to the user (i.e., "Where\'d all my free space go?").\n3. The revi

In [12]:
target_modules = find_all_linear_names(args, model, add_lm_head=False)
target_modules

['q_proj', 'v_proj', 'up_proj', 'k_proj', 'down_proj', 'o_proj', 'gate_proj']

In [13]:
# Training arguments
training_args = TrainingArguments(
    do_train=True,
    output_dir=args.output_dir,
    save_strategy="steps",
    logging_strategy="steps",
    num_train_epochs=args.epochs,
    logging_steps=1,
    per_device_train_batch_size=args.batch_size,
    optim=args.optimizer,
    learning_rate=args.learning_rate,
    lr_scheduler_type=args.lr_scheduler_type,
    warmup_steps=args.warmup_steps,
    weight_decay=args.weight_decay,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    gradient_checkpointing=args.gradient_checkpointing,
    report_to="wandb",
    save_total_limit=args.save_limit,
    bf16=args.bf16,
    fp16=args.fp16,
    max_steps=1000,
)

peft_config = LoraConfig(
    r=args.lora_rank,
    lora_alpha=args.lora_alpha,
    lora_dropout=args.lora_dropout,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=target_modules
)

# Create DPO trainer
dpo_trainer = DPOTrainer(
    model,
    ref_model,
    args=training_args,
    peft_config=peft_config,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    beta=0.1,
    max_prompt_length=args.block_size,
    max_length=args.block_size*2,
    max_target_length=args.block_size,
)

Map: 100%|██████████| 12216/12216 [00:46<00:00, 265.35 examples/s]


In [14]:
dpo_trainer.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjihengcu[0m. Use [1m`wandb login --relogin`[0m to force relogin


The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss
1,0.699
2,0.7242
3,0.6938
4,0.6808
5,0.7088
6,0.6521
7,0.6831
8,0.638
9,0.684
10,0.6585




TrainOutput(global_step=1000, training_loss=0.025301375399152313, metrics={'train_runtime': 5051.9564, 'train_samples_per_second': 0.396, 'train_steps_per_second': 0.198, 'total_flos': 0.0, 'train_loss': 0.025301375399152313, 'epoch': 0.16})

In [15]:
dpo_trainer.model.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir)

('./mistral_dpo/tokenizer_config.json',
 './mistral_dpo/special_tokens_map.json',
 './mistral_dpo/tokenizer.model',
 './mistral_dpo/added_tokens.json',
 './mistral_dpo/tokenizer.json')

In [16]:
import gc
# Flush memory
del dpo_trainer, model, ref_model
gc.collect()
torch.cuda.empty_cache()

In [12]:
from peft import PeftModel

args = load_config("configs/mistral-dpo.yml")

# Reload model in FP16 (instead of NF4)
base_model = AutoModelForCausalLM.from_pretrained(
    args.model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(args.output_dir)

# Merge base model with the adapter

Loading checkpoint shards: 100%|██████████| 3/3 [00:03<00:00,  1.09s/it]


In [14]:
prompt = """<|im_start|>system
You are an AI assistant. You will be given a task. You must generate a detailed and long answer.<|im_end|>
<|im_start|>user
Summarize this article in one sentence. The incident happened on Fife Street between 04:30 BST and 05:00 BST on Saturday. The police have appealed for witnesses to the crash. York Road is closed between its junctions with Skegoneill Avenue and Alexandra Park Avenue. Diversions are in place. Summary:<|im_end|>
<|im_start|>assistant
"""
model_input = tokenizer(prompt, return_tensors="pt").to("cuda")


In [15]:
with torch.no_grad():
    output = base_model.generate(**model_input, max_new_tokens=500, temperature=0.7, do_sample=True,
                            eos_token_id=tokenizer.eos_token_id,
                            pad_token_id=tokenizer.pad_token_id)[0]
    print(tokenizer.decode(output))



<s><|im_start|> system
You are an AI assistant. You will be given a task. You must generate a detailed and long answer.<|im_end|> 
<|im_start|> user
Summarize this article in one sentence. The incident happened on Fife Street between 04:30 BST and 05:00 BST on Saturday. The police have appealed for witnesses to the crash. York Road is closed between its junctions with Skegoneill Avenue and Alexandra Park Avenue. Diversions are in place. Summary:<|im_end|> 
<|im_start|> assistant
 On Saturday between 04:30 BST and 05:00 BST, a severe accident occurred on Fife Street, leading the police to close York Road between Skegoneill Avenue and Alexandra Park Avenue, with diversions in place, and urgently appealing for witnesses to aid in their investigation of the crash.<|im_end|>


In [16]:
lora_model = PeftModel.from_pretrained(base_model, args.output_dir)

In [17]:
with torch.no_grad():
    output = lora_model.generate(**model_input, max_new_tokens=500, temperature=0.7, do_sample=True,
                            eos_token_id=tokenizer.eos_token_id,
                            pad_token_id=tokenizer.pad_token_id)[0]
    print(tokenizer.decode(output))

<s><|im_start|> system
You are an AI assistant. You will be given a task. You must generate a detailed and long answer.<|im_end|> 
<|im_start|> user
Summarize this article in one sentence. The incident happened on Fife Street between 04:30 BST and 05:00 BST on Saturday. The police have appealed for witnesses to the crash. York Road is closed between its junctions with Skegoneill Avenue and Alexandra Park Avenue. Diversions are in place. Summary:<|im_end|> 
<|im_start|> assistant
During the specified timeframe on a Saturday, an unspecified incident took place on Fife Street that led to the subsequent closure of York Road between Skegoneill Avenue and Alexandra Park Avenue while police seek information from potential witnesses; in the meantime, traffic is being redirected via designated diversions in response to the situation.<|im_end|>
