In [1]:
import nb_utils

In [2]:
import os

# Set before the cell that imports torch so the restriction is already in the environment.
# NOTE(review): the value contains a space after the comma ("0, 1") — confirm the CUDA runtime parses it as two devices.
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1"

In [3]:
# |export
import os, re
from typing import Dict, Literal

import torch
from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizer

from accelerate import Accelerator
from huggingface_hub import list_repo_files
from huggingface_hub.utils._validators import HFValidationError
from peft import LoraConfig, PeftConfig

from training_lib.configs import DataArguments, ModelArguments

In [4]:
# |export
# Fallback Jinja chat template, assigned by `get_tokenizer` when neither the data
# arguments nor the tokenizer itself provide one. Each message is rendered as a role
# tag (`<|user|>`, `<|system|>` or `<|assistant|>`), a newline, the message content and
# `eos_token`. When `add_generation_prompt` is set, a bare `<|assistant|>` tag is emitted
# after the last message.
DEFAULT_CHAT_TEMPLATE = """\
{% for message in messages %}
{% if message['role'] == 'user' %}
{{ '<|user|>\n' + message['content'] + eos_token }}
{% elif message['role'] == 'system' %}
{{ '<|system|>\n' + message['content'] + eos_token }}
{% elif message['role'] == 'assistant' %}
{{ '<|assistant|>\n'  + message['content'] + eos_token }}
{% endif %}
{% if loop.last and add_generation_prompt %}
{{ '<|assistant|>' }}
{% endif %}
{% endfor %}"""

## Tokenizer

In [5]:
# |export
def get_tokenizer(model_args: ModelArguments, data_args: DataArguments) -> PreTrainedTokenizer:
    """Load the model's tokenizer and apply the project's defaults to it.

    Args:
        model_args: Supplies `model_name_or_path` and `model_revision` for loading.
        data_args: Supplies the optional `truncation_side` and `chat_template` overrides.

    Returns:
        The loaded tokenizer with pad token, truncation side, maximum length and
        chat template adjusted as described in the comments below.
    """
    tok = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        revision=model_args.model_revision,
    )

    # Reuse the EOS token for padding when the tokenizer defines no pad token.
    if tok.pad_token_id is None:
        tok.pad_token_id = tok.eos_token_id

    # An explicit truncation side from the data arguments overrides the tokenizer's own.
    side = data_args.truncation_side
    if side is not None:
        tok.truncation_side = side

    # A very large `model_max_length` is treated as "not set" and replaced by 2048.
    if tok.model_max_length > 100_000:
        tok.model_max_length = 2048

    # Template priority: data arguments, then the tokenizer's own, then the module default.
    custom_template = data_args.chat_template
    if custom_template is not None:
        tok.chat_template = custom_template
    elif tok.chat_template is None:
        tok.chat_template = DEFAULT_CHAT_TEMPLATE

    return tok

In [6]:
# Example: build the tokenizer for Mistral-7B (only a dataset mixer is set on the data arguments).
data_args = DataArguments(dataset_mixer={"HuggingFaceH4/ultrachat_200k": 1.0})
model_args = ModelArguments(model_name_or_path="mistralai/Mistral-7B-v0.1")

tokenizer = get_tokenizer(model_args, data_args)

print(tokenizer)
print("=" * 80)
# Show which chat template ended up on the tokenizer.
print(tokenizer.chat_template)

LlamaTokenizerFast(name_or_path='mistralai/Mistral-7B-v0.1', vocab_size=32000, model_max_length=2048, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '</s>'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
{% for message in messages %}
{% if message['role'] == 'user' %}
{{ '<|user|>
' + message['content'] + eos_token }}
{% elif message['role'] == 'system' %}
{{ '<|system|>
' + message['content'] + eos_token }}
{% elif message['role'] == 'assistant' %}
{{ '<|assistant|>
'  + message['content'] + eos_token }}
{% endif %}
{% if loop.last and add_generation_prompt %}
{

In [7]:
# |export
def apply_chat_template(example, tokenizer, task: Literal["sft", "generation", "rm", "dpo"] = "sft", assistant_prefix="<|assistant|>\n"):
    """Render the conversation(s) in `example` to text with the tokenizer's chat template.

    Args:
        example: Mapping holding conversations as lists of `{"role": ..., "content": ...}`
            messages: under `messages` for `sft`/`generation`, under `chosen` and
            `rejected` for `rm`/`dpo`. The rendered text is stored on this mapping;
            the message lists themselves are left unmodified.
        tokenizer: Object exposing
            `apply_chat_template(messages, tokenize=False, add_generation_prompt=...)`.
        task: Selects the formatting.
            - `sft` / `generation`: writes `text`; `generation` also requests the
              generation prompt.
            - `rm`: writes `text_chosen` and `text_rejected` for the full dialogues.
            - `dpo`: writes `text_prompt` (system message + first user message, with
              generation prompt) and `text_chosen` / `text_rejected` (the turns after
              the prompt, with `assistant_prefix` removed from the start).
        assistant_prefix: Prefix stripped from the start of the `dpo` completions.

    Returns:
        The same `example` mapping with the text keys added.

    Raises:
        ValueError: If `task` is unsupported, or `chosen`/`rejected` are missing for
            `rm`/`dpo`.
    """

    def _has_system(messages):
        return messages[0]["role"] == "system"

    def _with_system(messages):
        # Build a new list instead of inserting in place, so formatting the same example
        # for another task afterwards still sees the original conversation.
        if _has_system(messages):
            return list(messages)
        return [{"role": "system", "content": ""}, *messages]

    def _completion(messages):
        # Drop the optional leading system message and the prompt turn that follows it,
        # so the prompt is not repeated inside the completion text.
        return messages[2:] if _has_system(messages) else messages[1:]

    def _require_preference_keys():
        if not all(k in example.keys() for k in ("chosen", "rejected")):
            raise ValueError(
                f"Could not format example as dialogue for `{task}` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
            )

    if task in ("sft", "generation"):
        # An empty system message is added if there is none
        example["text"] = tokenizer.apply_chat_template(
            _with_system(example["messages"]), tokenize=False, add_generation_prompt=(task == "generation")
        )
    elif task == "rm":
        _require_preference_keys()
        example["text_chosen"] = tokenizer.apply_chat_template(_with_system(example["chosen"]), tokenize=False)
        example["text_rejected"] = tokenizer.apply_chat_template(_with_system(example["rejected"]), tokenize=False)
    elif task == "dpo":
        _require_preference_keys()
        chosen, rejected = example["chosen"], example["rejected"]

        # Compared to reward modeling, the prompt is split off: it is the system message
        # (an empty one if absent) followed by the first user message.
        first_user = [msg for msg in chosen if msg["role"] == "user"][0]
        system_message = chosen[0] if _has_system(chosen) else {"role": "system", "content": ""}
        prompt_messages = [system_message, first_user]

        text_chosen = tokenizer.apply_chat_template(_completion(chosen), tokenize=False)
        text_rejected = tokenizer.apply_chat_template(_completion(rejected), tokenize=False)
        example["text_prompt"] = tokenizer.apply_chat_template(prompt_messages, tokenize=False, add_generation_prompt=True)

        # The prompt already ends with the generation prompt, so the assistant tag is
        # removed from the start of each completion.
        example["text_chosen"] = text_chosen.removeprefix(assistant_prefix)
        example["text_rejected"] = text_rejected.removeprefix(assistant_prefix)
    else:
        raise ValueError(f"Task {task} not supported, please ensure that the provided task is one of {['sft', 'generation', 'rm', 'dpo']}")
    return example

In [8]:
from training_lib.data import get_datasets

# Example: format one UltraChat conversation for the `sft` and `generation` tasks.
data_args = DataArguments(dataset_mixer={"HuggingFaceH4/ultrachat_200k": 1.0})
model_args = ModelArguments(model_name_or_path="mistralai/Mistral-7B-v0.1")

raw_datasets = get_datasets(data_args, splits=["train_sft", "test_sft"])
tokenizer = get_tokenizer(model_args, data_args)

print(raw_datasets)
print("=" * 80)

# Raw example before any template is applied.
example = raw_datasets["train"][0]
print(example)
print("=" * 80)

# `sft`: rendered dialogue without a generation prompt.
res = apply_chat_template(example, tokenizer, "sft")
print(res["text"])
print("=" * 80)

# `generation`: same dialogue, rendered with `add_generation_prompt=True`.
res = apply_chat_template(example, tokenizer, "generation")
print(res["text"])

DatasetDict({
    train: Dataset({
        features: ['prompt', 'prompt_id', 'messages'],
        num_rows: 207865
    })
    test: Dataset({
        features: ['prompt', 'prompt_id', 'messages'],
        num_rows: 23110
    })
})
{'prompt': 'How does the location of the Sydney Conservatorium of Music impact the academic and professional opportunities available to music students, and how does the conservatorium support student engagement with the music industry in Australia?', 'prompt_id': 'bc82021755d49d219f182fdd76ccfbd97ec9db38b1d12e1b891434e1477057f1', 'messages': [{'content': 'How does the location of the Sydney Conservatorium of Music impact the academic and professional opportunities available to music students, and how does the conservatorium support student engagement with the music industry in Australia?', 'role': 'user'}, {'content': "The location of the Sydney Conservatorium of Music, which is situated in the heart of Sydney's cultural precinct, impacts both the academic an

In [9]:
from training_lib.data import get_datasets

# Example: format one UltraFeedback preference pair for the `rm` and `dpo` tasks.
data_args = DataArguments(dataset_mixer={"HuggingFaceH4/ultrafeedback_binarized": 1.0})
model_args = ModelArguments(model_name_or_path="mistralai/Mistral-7B-v0.1")

raw_datasets = get_datasets(data_args, splits=["train_prefs", "test_prefs"])
tokenizer = get_tokenizer(model_args, data_args)

print(raw_datasets)
print("=" * 80)

# Raw example before any template is applied.
example = raw_datasets["train"][0]
print(example)
print("=" * 80)

# `rm`: show only the scores and the rendered chosen/rejected dialogues.
res = apply_chat_template(example, tokenizer, "rm")
print({k: v for k, v in res.items() if k in ["score_chosen", "score_rejected", "text_chosen", "text_rejected"]})
print("=" * 80)

# `dpo`: show the rendered prompt and the two completions.
res = apply_chat_template(example, tokenizer, "dpo")
print({k: v for k, v in res.items() if k in ["text_prompt", "text_chosen", "text_rejected"]})
print("=" * 80)

DatasetDict({
    train: Dataset({
        features: ['prompt', 'prompt_id', 'chosen', 'rejected', 'messages', 'score_chosen', 'score_rejected'],
        num_rows: 61966
    })
    test: Dataset({
        features: ['prompt', 'prompt_id', 'chosen', 'rejected', 'messages', 'score_chosen', 'score_rejected'],
        num_rows: 2000
    })
})
{'prompt': 'Please provide the content structure of the following text using [Latex] data format. Additionally, please ensure that the structure includes a table comparing the results of the two groups, as well as a code snippet detailing the statistical analysis performed in the study. The results of the study showed a significant difference between the two groups. The study also found that the treatment was effective for a wide range of conditions.', 'prompt_id': '5cf991718c2849e6d0312f999314dfb0e6dfc1a10234f461ea321b9c038262be', 'chosen': [{'content': 'Please provide the content structure of the following text using [Latex] data format. Additionall

## Model

### Utilities

In [10]:
# |export
def get_current_device() -> int:
    """Get the current device. For GPU we return the local process index to enable multiple GPU training."""
    return Accelerator().local_process_index if torch.cuda.is_available() else "cpu"

In [11]:
# Show the device value selected in this kernel.
get_current_device()

0

In [12]:
# |export
def get_kbit_device_map() -> Dict[str, int] | None:
    """Useful for running inference with quantized models by setting `device_map=get_peft_device_map()`"""
    return {"": get_current_device()} if torch.cuda.is_available() else None

In [13]:
# Show the device map produced in this kernel.
get_kbit_device_map()

{'': 0}

### Quantization

In [14]:
# |export
def get_quantization_config(model_args) -> BitsAndBytesConfig | None:
    """Build the bitsandbytes configuration requested by `model_args`.

    `load_in_4bit` takes precedence over `load_in_8bit`.

    Args:
        model_args: Object providing `load_in_4bit`, `load_in_8bit`,
            `bnb_4bit_quant_type` and `use_bnb_nested_quant`.

    Returns:
        A `BitsAndBytesConfig` for 4-bit or 8-bit loading, or `None` when neither
        flag is set.
    """
    if model_args.load_in_4bit:
        return BitsAndBytesConfig(
            load_in_4bit=True,
            # For consistency with model weights, we use the same value as `torch_dtype` which is float16 for PEFT models
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
        )

    if model_args.load_in_8bit:
        return BitsAndBytesConfig(load_in_8bit=True)

    # No quantization requested.
    return None

In [16]:
# Example: configuration produced when 4-bit loading is requested.
model_args = ModelArguments(model_name_or_path="mistralai/Mistral-7B-v0.1", load_in_4bit=True)
get_quantization_config(model_args)

BitsAndBytesConfig {
  "bnb_4bit_compute_dtype": "float16",
  "bnb_4bit_quant_type": "nf4",
  "bnb_4bit_use_double_quant": false,
  "llm_int8_enable_fp32_cpu_offload": false,
  "llm_int8_has_fp16_weight": false,
  "llm_int8_skip_modules": null,
  "llm_int8_threshold": 6.0,
  "load_in_4bit": true,
  "load_in_8bit": false,
  "quant_method": "bitsandbytes"
}

In [17]:
# Example: configuration produced when 8-bit loading is requested.
model_args = ModelArguments(model_name_or_path="mistralai/Mistral-7B-v0.1", load_in_8bit=True)
get_quantization_config(model_args)

BitsAndBytesConfig {
  "bnb_4bit_compute_dtype": "float32",
  "bnb_4bit_quant_type": "fp4",
  "bnb_4bit_use_double_quant": false,
  "llm_int8_enable_fp32_cpu_offload": false,
  "llm_int8_has_fp16_weight": false,
  "llm_int8_skip_modules": null,
  "llm_int8_threshold": 6.0,
  "load_in_4bit": false,
  "load_in_8bit": true,
  "quant_method": "bitsandbytes"
}

### PEFT

In [18]:
# |export
def get_peft_config(model_args: ModelArguments) -> PeftConfig | None:
    """Build the LoRA configuration described by `model_args`.

    Args:
        model_args: Supplies `use_peft` and the `lora_*` settings.

    Returns:
        `None` when `model_args.use_peft` is `False`; otherwise a `LoraConfig` for
        causal language modeling with no bias training.
    """
    if model_args.use_peft is False:
        return None

    lora_kwargs = dict(
        r=model_args.lora_r,
        lora_alpha=model_args.lora_alpha,
        lora_dropout=model_args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=model_args.lora_target_modules,
        modules_to_save=model_args.lora_modules_to_save,
    )
    return LoraConfig(**lora_kwargs)

In [22]:
# Example: LoRA configuration built with PEFT enabled and rank 64; display its fields.
model_args = ModelArguments(model_name_or_path="mistralai/Mistral-7B-v0.1", load_in_8bit=True, use_peft=True, lora_r=64)
peft_config = get_peft_config(model_args)
peft_config.__dict__

{'peft_type': <PeftType.LORA: 'LORA'>,
 'auto_mapping': None,
 'base_model_name_or_path': None,
 'revision': None,
 'task_type': 'CAUSAL_LM',
 'inference_mode': False,
 'r': 64,
 'target_modules': None,
 'lora_alpha': 32,
 'lora_dropout': 0.05,
 'fan_in_fan_out': False,
 'bias': 'none',
 'modules_to_save': None,
 'init_lora_weights': True,
 'layers_to_transform': None,
 'layers_pattern': None,
 'rank_pattern': {},
 'alpha_pattern': {}}

In [24]:
#|export
def is_adapter_model(model_name_or_path: str, revision: str = "main") -> bool:
    """Check whether a model location contains adapter weights.

    Args:
        model_name_or_path: Local directory or Hub repo id to inspect.
        revision: Hub revision to list; not used when a local directory is inspected.

    Returns:
        True if `adapter_model.safetensors` or `adapter_model.bin` is among the files.

    Raises:
        OSError: If the name is neither an existing directory nor a valid repo id
            (raised by `os.listdir`). Errors from the Hub lookup other than
            `HFValidationError` propagate unchanged.
    """
    if os.path.isdir(model_name_or_path):
        # Prefer an existing local directory. A relative path such as `output/run1` is
        # also a well-formed repo id, so the Hub lookup is expected to fail with a
        # repository-not-found error rather than the `HFValidationError` that triggers
        # the local fallback below.
        repo_files = os.listdir(model_name_or_path)
    else:
        try:
            repo_files = list_repo_files(model_name_or_path, revision=revision)
        except HFValidationError:
            # Not a valid repo id and not a directory: let `os.listdir` raise the
            # filesystem error, as before.
            repo_files = os.listdir(model_name_or_path)

    return any(name in repo_files for name in ("adapter_model.safetensors", "adapter_model.bin"))

## Export

In [25]:
from nbdev.export import nb_export

# Export this notebook's `# |export` cells to the `model_utils` module under `../training_lib/`.
nb_export("20_model_utils.ipynb", lib_path="../training_lib/", name="model_utils")