This notebook uses **unsloth**, a framework that helps fine-tune LLMs faster and with less memory.

<a href="https://github.com/unslothai/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>

# **Setup**

In [None]:
%%capture
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --no-deps "xformers<0.0.26" trl peft accelerate bitsandbytes
!pip install gradio

In [None]:
from typing import List, Tuple, Dict

import gc
import random
import pandas as pd
import torch

from transformers import (
    LlamaForCausalLM,
    MistralForCausalLM,
    GemmaForCausalLM,
    PreTrainedTokenizerFast,
    LlamaTokenizerFast,
    GemmaTokenizerFast,
)
from transformers import TrainingArguments

from huggingface_hub import get_token

from unsloth import FastLanguageModel, PatchDPOTrainer
from peft import PeftModelForCausalLM, LoftQConfig

from datasets import load_dataset, Dataset
from trl import DPOTrainer

import gradio as gr
from gradio import Textbox, Dataframe

# **Functions**

In [None]:
# @title Some params

# Datasets offered in the "Dataset name" dropdown
dataset_name_list = list(
    (
        "argilla/distilabel-intel-orca-dpo-pairs",
        "Intel/orca_dpo_pairs",
        "jondurbin/py-dpo-v0.1",
        "jondurbin/truthy-dpo-v0.1",
    )
)

# Models offered in the "Model name" dropdown, assembled family by family.
# More models at https://huggingface.co/unsloth
model_name_list = (
    # very small model for fast testing
    ["unsloth/tinyllama-bnb-4bit"]
    # Llama
    + [
        "unsloth/llama-3-8b-bnb-4bit",
        "unsloth/llama-3-8b-Instruct-bnb-4bit",
        "unsloth/llama-2-7b-bnb-4bit",
        "unsloth/yi-6b-bnb-4bit",
    ]
    # Mistral
    + [
        "unsloth/mistral-7b-bnb-4bit",
        "unsloth/mistral-7b-v0.2-bnb-4bit",
        "unsloth/mistral-7b-instruct-v0.2-bnb-4bit",
        "unsloth/zephyr-sft-bnb-4bit",
        "unsloth/OpenHermes-2.5-Mistral-7B-bnb-4bit",
        "unsloth/Hermes-2-Pro-Mistral-7B-bnb-4bit",
    ]
    # Gemma
    + [
        "unsloth/gemma-7b-bnb-4bit",
        "unsloth/gemma-7b-it-bnb-4bit",
        "unsloth/codegemma-7b-bnb-4bit",
    ]
)


# Powers of two from 2^start up to 2^end, both ends included
def exp2_list(start: int = 0, end: int = 0) -> List[int]:
    """
    Return [2^start, 2^(start + 1), ..., 2^end].

    Args:
      start (int, optional): Exponent of the first element. Defaults to 0.
      end (int, optional): Exponent of the last element. Defaults to 0.

    Returns:
      List[int]: One power of 2 per exponent in the closed range
      [start, end], in increasing order.

    Raises:
      AssertionError: If a bound is not an integer, if start is negative,
      or if end is smaller than start.
    """
    # Argument checks, in the same order as the messages they produce
    assert isinstance(start, int), "start must be an integer"
    assert isinstance(end, int), "end must be an integer"
    assert start >= 0, "start must be non-negative"
    assert end >= start, "end must be greater than or equal to start"

    powers: List[int] = []
    for exponent in range(start, end + 1):
        powers.append(2**exponent)
    return powers


# Module names offered in the "Target modules to be trained" dropdown
target_module_list = ["q_proj", "k_proj", "v_proj", "o_proj"] + [
    "gate_proj",
    "up_proj",
    "down_proj",
]

# Values offered for the gradient checkpointing option
use_gradient_checkpointing_list = list(("unsloth", True))

# Optimizer names offered in the "Optimizer" dropdown
optim_list = list(("adamw_8bit", "paged_adamw_8bit"))

In [None]:
# @title Get dataset


def get_dataset(
    dataset_name: str,
    max_samples: int = 1000,
    test_size: float = 0.2,
    train_test_split_seed: int = 42,
) -> Tuple[Textbox, Textbox, Textbox, Dataset, Dataset]:
    """
    Fetch a preference dataset and get it ready for DPO training.

    The "train" split of `dataset_name` is loaded, reduced to its first
    `max_samples` rows (or to all rows when it has fewer), mapped through
    `return_prompt_and_responses` so that only the prompt / chosen / rejected
    columns remain, and then divided into a training and an evaluation part.
    One randomly picked training row is wrapped into three visible Textboxes.
    Progress is reported with `gr.Info` after every step.

    Args:
        dataset_name (str): Name passed to `load_dataset`.
        max_samples (int, optional): Upper bound on the number of rows kept. Defaults to 1000.
        test_size (float, optional): `test_size` passed to `train_test_split`. Defaults to 0.2.
        train_test_split_seed (int, optional): `seed` passed to `train_test_split`. Defaults to 42.

    Returns:
        Tuple[Textbox, Textbox, Textbox, Dataset, Dataset]: Textboxes holding the
        sampled row's prompt, chosen and rejected texts, followed by the
        training dataset and the evaluation dataset.

    Raises:
        gr.Error: If any step fails; the message contains the original error text.
    """
    gr.Info("Getting dataset.")
    try:
        # Only the "train" split of the source dataset is used
        raw = load_dataset(dataset_name, split="train")
        gr.Info("Raw data successfully loaded.")

        # Keep the first rows only, never more than the split contains
        sample_size = min(max_samples, len(raw))
        subset = raw.select(range(sample_size))
        gr.Info(f"{sample_size} rows kept.")

        # Replace every original column by prompt / chosen / rejected
        pairs = subset.map(
            return_prompt_and_responses,
            num_proc=16,
            remove_columns=subset.column_names,
        )
        gr.Info("Data transformed successfully.")

        # Seeded split into training and evaluation parts
        splits = pairs.train_test_split(
            test_size=test_size, seed=train_test_split_seed
        )
        train_dataset, eval_dataset = splits["train"], splits["test"]
        gr.Info("Train test split done.")

        # Pick one training row to display; read all three fields before
        # building any component
        picked = random.choice(train_dataset)
        texts = [picked[column] for column in ("prompt", "chosen", "rejected")]
        boxes = [gr.Textbox(value=text, visible=True) for text in texts]

        gr.Info("Data is ready.")

        return (*boxes, train_dataset, eval_dataset)

    except Exception as e:
        raise gr.Error(f"Error getting dataset: {str(e)}")


# Prepare dataset for DPO fine-tuning
def return_prompt_and_responses(samples: Dict[str, str]) -> Dict[str, str]:
    """
    Turn one raw dataset row into the prompt / chosen / rejected triple.

    The prompt text is taken from the first of the "prompt", "question" and
    "input" entries whose value is not None. It is then wrapped in the fixed
    chat template used below, and the end-of-sequence marker is appended to
    both responses.

    Args:
      samples (Dict[str, str]): One dataset row. It must contain "chosen" and
        "rejected", plus at least one of "prompt", "question" or "input".

    Returns:
      Dict[str, str]: A dictionary with exactly the keys "prompt", "chosen"
      and "rejected", holding the templated texts.

    Raises:
      ValueError: If none of the prompt entries holds a value.
      KeyError: If "chosen" or "rejected" is missing from the row.
    """
    prompt_columns = ("prompt", "question", "input")

    # First entry that is present and not None wins, in the order above
    prompt = None
    for column in prompt_columns:
        prompt = samples.get(column)
        if prompt is not None:
            break

    # Fail with an explicit message instead of the "str + None" TypeError the
    # concatenation below would otherwise produce
    if prompt is None:
        raise ValueError(
            "No prompt found in sample: expected one of the columns "
            f"{', '.join(prompt_columns)}; got {sorted(samples)}"
        )

    # Apply chat template
    prompt = "<|user|>\n" + prompt + "</s>\n<|assistant|>\n"
    chosen = samples["chosen"] + "</s>\n"
    rejected = samples["rejected"] + "</s>\n"

    return {
        "prompt": prompt,
        "chosen": chosen,
        "rejected": rejected,
    }

In [None]:
# @title Get model


def get_model(
    model_name: str,
    token: str,
    max_seq_length: int = 256,
    load_in_4bit: bool = True,
) -> Tuple[
    LlamaForCausalLM | MistralForCausalLM | GemmaForCausalLM,
    PreTrainedTokenizerFast | LlamaTokenizerFast | GemmaTokenizerFast,
]:
    """
    Load a base model together with its tokenizer.

    All arguments are forwarded to `FastLanguageModel.from_pretrained`, with
    `dtype` passed as None. Progress is reported with `gr.Info`.

    Args:
        model_name (str): Name or path of the pre-trained model.
        token (str): Authentication token used for the download.
        max_seq_length (int, optional): Maximum sequence length. Defaults to 256.
        load_in_4bit (bool, optional): Whether to load the model in 4-bit precision. Defaults to True.

    Returns:
        Tuple[LlamaForCausalLM | MistralForCausalLM | GemmaForCausalLM, PreTrainedTokenizerFast | LlamaTokenizerFast | GemmaTokenizerFast]: The loaded model and its tokenizer, in that order.

    Raises:
        gr.Error: If loading fails; the message contains the original error text.
    """
    gr.Info("Getting model.")

    # Gather the loader options first so the call itself stays short
    load_options = {
        "model_name": model_name,
        "token": token,
        "max_seq_length": max_seq_length,
        "load_in_4bit": load_in_4bit,
        "dtype": None,
    }

    try:
        model, tokenizer = FastLanguageModel.from_pretrained(**load_options)
        gr.Info("Model is ready.")
    except Exception as e:
        raise gr.Error(f"Error getting model: {str(e)}")

    return model, tokenizer

In [None]:
# @title Get PEFT model


def get_peft_model(
    model: LlamaForCausalLM | MistralForCausalLM | GemmaForCausalLM,
    r: int = 8,
    lora_alpha: int = 8,
    peft_model_random_state: int = 42,
    target_modules: List[str] = [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    loftq_config: None | LoftQConfig = None,
    use_rslora: bool = False,
    lora_dropout: float = 0,
    bias: str = "none",
    use_gradient_checkpointing: str | bool = "unsloth",
) -> PeftModelForCausalLM:
    """
    Wrap a base model into a PEFT (LoRA) model.

    Every option is forwarded to `FastLanguageModel.get_peft_model`. Progress
    is reported with `gr.Info`.

    Args:
        model (LlamaForCausalLM | MistralForCausalLM | GemmaForCausalLM): The base model to adapt.
        r (int, optional): The LoRA rank (shown as "Lora rank" in the UI). Defaults to 8.
        lora_alpha (int, optional): The alpha parameter for LoRA. Defaults to 8.
        peft_model_random_state (int, optional): Passed as `random_state`. Defaults to 42.
        target_modules (List[str], optional): Names of the modules to adapt. Defaults to ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]. A list is copied before being forwarded.
        loftq_config (None | LoftQConfig, optional): LoftQ configuration. A blank string is treated as None. Defaults to None.
        use_rslora (bool, optional): Whether to use rank-stabilized LoRA. Defaults to False.
        lora_dropout (float, optional): The dropout rate for LoRA. Defaults to 0.
        bias (str, optional): The type of bias to use. Defaults to "none".
        use_gradient_checkpointing (str | bool, optional): Gradient checkpointing setting. Defaults to "unsloth".

    Returns:
        PeftModelForCausalLM: The PEFT model built on top of `model`.

    Raises:
        gr.Error: If the PEFT model cannot be built; the message contains the
        original error text and the original exception is attached as the cause.
    """
    gr.Info("Getting PEFT model.")

    # The default of `target_modules` is a single list object shared by every
    # call. Forward a private copy so nothing downstream can alter it.
    if isinstance(target_modules, list):
        target_modules = list(target_modules)

    # The UI feeds this argument from a Textbox, which can deliver an empty
    # string instead of None; treat a blank string as "no LoftQ config".
    if isinstance(loftq_config, str) and not loftq_config.strip():
        loftq_config = None

    try:
        peft_model = FastLanguageModel.get_peft_model(
            model,
            r=r,
            lora_alpha=lora_alpha,
            random_state=peft_model_random_state,
            target_modules=target_modules,
            loftq_config=loftq_config,
            use_rslora=use_rslora,
            lora_dropout=lora_dropout,
            bias=bias,
            use_gradient_checkpointing=use_gradient_checkpointing,
        )

        gr.Info("PEFT model is ready.")

        return peft_model

    except Exception as e:
        # Chain explicitly so the original traceback stays attached
        raise gr.Error(f"Error getting PEFT model: {str(e)}") from e

In [None]:
# @title Get DPO trainer


def get_dpo_trainer(
    peft_model: PeftModelForCausalLM,
    tokenizer: PreTrainedTokenizerFast | LlamaTokenizerFast | GemmaTokenizerFast,
    train_dataset: Dataset,
    eval_dataset: Dataset,
    per_device_train_batch_size: int = 4,
    num_train_epochs: int = 3,
    learning_rate: float = 5e-6,
    logging_steps: int = 50,
    optim: str = "adamw_8bit",
    dpo_trainer_seed: int = 42,
    beta: float = 0.1,
    max_length: int = 256,
    max_prompt_length: int = 128,
    output_dir: str = "fine-tuned-model",
    push_to_hub: bool = True,
    gradient_accumulation_steps: int = 4,
    warmup_ratio: float = 0.1,
    weight_decay: float = 0.0,
    lr_scheduler_type: str = "linear",
) -> DPOTrainer:
    """
    Build a DPO trainer for a PEFT model.

    `PatchDPOTrainer()` is called first, then a `DPOTrainer` is created with
    `ref_model=None` and a `TrainingArguments` object assembled from the
    arguments below. Exactly one of fp16 / bf16 is enabled, depending on
    `torch.cuda.is_bf16_supported()`. As a side effect, `output_dir` is stored
    in the module-level `OUTPUT_DIR.value`, which `save_model` reads later.

    Args:
        peft_model (PeftModelForCausalLM): The PEFT model to train.
        tokenizer (PreTrainedTokenizerFast | LlamaTokenizerFast | GemmaTokenizerFast): The tokenizer corresponding to the model.
        train_dataset (Dataset): The training dataset.
        eval_dataset (Dataset): The evaluation dataset.
        per_device_train_batch_size (int, optional): Batch size per device for training. Defaults to 4.
        num_train_epochs (int, optional): Number of training epochs. Defaults to 3.
        learning_rate (float, optional): Learning rate. Defaults to 5e-6.
        logging_steps (int, optional): Log every n steps. Defaults to 50.
        optim (str, optional): Optimizer type. Defaults to "adamw_8bit".
        dpo_trainer_seed (int, optional): Passed as `seed` to the training arguments. Defaults to 42.
        beta (float, optional): `beta` passed to `DPOTrainer`. Defaults to 0.1.
        max_length (int, optional): `max_length` passed to `DPOTrainer`. Defaults to 256.
        max_prompt_length (int, optional): `max_prompt_length` passed to `DPOTrainer`. Defaults to 128.
        output_dir (str, optional): Output directory for the training arguments; also stored in `OUTPUT_DIR.value`. Defaults to "fine-tuned-model".
        push_to_hub (bool, optional): Whether to push to the Hub. Defaults to True.
        gradient_accumulation_steps (int, optional): Gradient accumulation steps. Defaults to 4.
        warmup_ratio (float, optional): Warmup ratio. Defaults to 0.1.
        weight_decay (float, optional): Weight decay. Defaults to 0.0.
        lr_scheduler_type (str, optional): Learning rate scheduler type. Defaults to "linear".

    Returns:
        DPOTrainer: The DPO trainer instance.

    Raises:
        gr.Error: If the trainer cannot be built; the message contains the
        original error text and the original exception is attached as the cause.
    """
    gr.Info("Getting DPO trainer.")
    try:
        # Patch the DPO trainer first
        PatchDPOTrainer()

        # Probe the hardware once; fp16 and bf16 are complementary flags
        use_bf16 = torch.cuda.is_bf16_supported()

        training_args = TrainingArguments(
            per_device_train_batch_size=per_device_train_batch_size,
            num_train_epochs=num_train_epochs,
            learning_rate=learning_rate,
            logging_steps=logging_steps,
            optim=optim,
            seed=dpo_trainer_seed,
            output_dir=output_dir,
            push_to_hub=push_to_hub,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_ratio=warmup_ratio,
            weight_decay=weight_decay,
            lr_scheduler_type=lr_scheduler_type,
            fp16=not use_bf16,
            bf16=use_bf16,
            remove_unused_columns=False,
        )

        dpo_trainer = DPOTrainer(
            model=peft_model,
            ref_model=None,
            tokenizer=tokenizer,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            args=training_args,
            beta=beta,
            max_length=max_length,
            max_prompt_length=max_prompt_length,
        )

        # Remember the target name for save_model, which reads OUTPUT_DIR.value
        OUTPUT_DIR.value = output_dir

        gr.Info("DPO trainer is ready.")

        return dpo_trainer

    except Exception as e:
        # Chain explicitly so the original traceback stays attached
        raise gr.Error(f"Error getting DPO trainer: {str(e)}") from e

In [None]:
# @title Fine-tune model


def finetune_model(
    dpo_trainer: DPOTrainer,
) -> Tuple[LlamaForCausalLM | MistralForCausalLM | GemmaForCausalLM, Dict[str, float]]:
    """
    Run the training loop of a prepared DPO trainer.

    Args:
        dpo_trainer (DPOTrainer): The trainer to run.

    Returns:
        A pair made of `dpo_trainer.model` after training and the value
        returned by `dpo_trainer.train()`. The latter is what
        `get_training_metrics` later reads `.metrics` from.

    Raises:
        gr.Error: If training fails; the message contains the original error text.
    """
    gr.Info("Training DPO model.")
    try:
        # Train, then pick the trained model up from the trainer
        train_output = dpo_trainer.train()
        trained_model = dpo_trainer.model
        gr.Info("Fine-tuning done.")
    except Exception as e:
        raise gr.Error(f"Error in fine-tuning process: {str(e)}")

    return trained_model, train_output

In [None]:
# @title Save model


def save_model(
    finetuned_model: LlamaForCausalLM | MistralForCausalLM | GemmaForCausalLM,
    token: str,
) -> None:
    """
    Push the fine-tuned model to the Hugging Face Hub.

    The destination name is read from the module-level `OUTPUT_DIR.value`,
    which `get_dpo_trainer` fills in.

    Args:
        finetuned_model (LlamaForCausalLM | MistralForCausalLM | GemmaForCausalLM): The model to upload.
        token (str): The authentication token used for the upload.

    Returns:
        None.

    Raises:
        gr.Error: If the upload fails; the message contains the original error text.
    """
    gr.Info("Saving DPO model.")
    try:
        # Destination recorded earlier by get_dpo_trainer
        destination = OUTPUT_DIR.value
        finetuned_model.push_to_hub(destination, token=token)
        gr.Info("Fine-tuned model successfully saved to Hugging Face.")
    except Exception as e:
        raise gr.Error(f"Error in saving process: {str(e)}")

In [None]:
# @title Get training metrics


def get_training_metrics(results: DPOTrainer) -> Dataframe:
    """
    Format the metrics of a finished training run as a two-column table.

    Args:
        results (DPOTrainer): In `run_all` this is the value returned by
            `dpo_trainer.train()`, not the trainer itself. Only its `.metrics`
            mapping (metric name -> number) is read.

    Returns:
        Dataframe: A visible, read-only Gradio dataframe with one row per
        metric and the columns "Metric" and "Value", each value formatted
        with two decimals.

    Raises:
        gr.Error: If formatting fails; the message contains the original error text.
    """
    gr.Info("Getting and formatting metrics.")
    try:
        # One single-element column per metric, value rendered with 2 decimals
        formatted = {}
        for metric, value in results.metrics.items():
            formatted[metric] = [f"{value:.2f}"]

        # Metrics start as columns; transpose so that each becomes a row
        wide = pd.DataFrame(formatted, index=["Value"])
        tall = wide.T.reset_index()
        table = tall.rename(columns={"index": "Metric"})

        metrics_view = gr.Dataframe(
            value=table,
            interactive=False,
            visible=True,
        )

        gr.Info("Metrics successfully retrieved.")
    except Exception as e:
        raise gr.Error(f"Error in retrieving metrics: {str(e)}")

    return metrics_view

In [None]:
# @title Run all


# Full pipeline, from raw dataset to training metrics
def run_all(
    # get_dataset args
    dataset_name,
    max_samples,
    test_size,
    train_test_split_seed,
    # get_model args
    model_name,
    token,
    max_seq_length,
    load_in_4bit,
    # get_peft_model args
    r,
    lora_alpha,
    peft_model_random_state,
    target_modules,
    loftq_config,
    use_rslora,
    lora_dropout,
    bias,
    use_gradient_checkpointing,
    # get_dpo_trainer args
    per_device_train_batch_size,
    num_train_epochs,
    learning_rate,
    logging_steps,
    optim,
    dpo_trainer_seed,
    beta,
    max_length,
    max_prompt_length,
    output_dir,
    push_to_hub,
    gradient_accumulation_steps,
    warmup_ratio,
    weight_decay,
    lr_scheduler_type,
):
    """
    Chain every step of the notebook: dataset preparation, model loading, PEFT model creation, DPO trainer creation, fine-tuning, upload of the fine-tuned model, and metrics formatting.

    Args:
      The arguments of `get_dataset`, then those of `get_model`,
      `get_peft_model` (without `model`) and `get_dpo_trainer` (without the
      model, tokenizer and datasets), in that order.

    Returns:
        Tuple containing the Textboxes of a sample prompt, its chosen and rejected responses, and the training metrics dataframe.
    """
    # Step 1: dataset
    (
        sample_prompt,
        sample_chosen,
        sample_rejected,
        train_dataset,
        eval_dataset,
    ) = get_dataset(
        dataset_name=dataset_name,
        max_samples=max_samples,
        test_size=test_size,
        train_test_split_seed=train_test_split_seed,
    )

    # Step 2: base model and tokenizer
    base_model, tokenizer = get_model(
        model_name=model_name,
        token=token,
        max_seq_length=max_seq_length,
        load_in_4bit=load_in_4bit,
    )

    # Step 3: PEFT model on top of the base model
    peft_model = get_peft_model(
        model=base_model,
        r=r,
        lora_alpha=lora_alpha,
        peft_model_random_state=peft_model_random_state,
        target_modules=target_modules,
        loftq_config=loftq_config,
        use_rslora=use_rslora,
        lora_dropout=lora_dropout,
        bias=bias,
        use_gradient_checkpointing=use_gradient_checkpointing,
    )

    # Step 4: DPO trainer
    trainer = get_dpo_trainer(
        peft_model=peft_model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        per_device_train_batch_size=per_device_train_batch_size,
        num_train_epochs=num_train_epochs,
        learning_rate=learning_rate,
        logging_steps=logging_steps,
        optim=optim,
        dpo_trainer_seed=dpo_trainer_seed,
        beta=beta,
        max_length=max_length,
        max_prompt_length=max_prompt_length,
        output_dir=output_dir,
        push_to_hub=push_to_hub,
        gradient_accumulation_steps=gradient_accumulation_steps,
        warmup_ratio=warmup_ratio,
        weight_decay=weight_decay,
        lr_scheduler_type=lr_scheduler_type,
    )

    # Step 5: training
    tuned_model, train_output = finetune_model(trainer)

    # Step 6: upload of the fine-tuned model
    save_model(tuned_model, token)

    # Step 7: metrics table for the UI
    training_metrics = get_training_metrics(train_output)

    # Drop the local references to the heavy objects, then collect twice
    del base_model, tokenizer, peft_model, trainer, tuned_model
    gc.collect()
    gc.collect()

    return sample_prompt, sample_chosen, sample_rejected, training_metrics

# **UI**

In [None]:
# @title UI

# Build the Gradio interface. The first tab holds every fine-tuning input and
# the button that triggers `run_all`; the second tab shows the metrics table.
with gr.Blocks() as demo:
    # Module-level state object: get_dpo_trainer writes its `.value` and
    # save_model reads it back.
    OUTPUT_DIR = gr.State("")

    with gr.Tab("DPO fine-tuning"):
        # GET THE DATA ---------------------------------------------------------
        gr.Markdown("# Get the data")
        with gr.Row():
            dataset_name = gr.Dropdown(
                choices=dataset_name_list,
                value="jondurbin/truthy-dpo-v0.1",
                label="Dataset name",
                interactive=True,
            )
            max_samples = gr.Number(
                value=1000,
                label="Maximum samples to retrieve",
                interactive=True,
                precision=0,
                minimum=100,
                step=1000,
            )
            test_size = gr.Slider(
                minimum=0.1,
                maximum=0.9,
                value=0.2,
                step=0.1,
                label="Test size",
                interactive=True,
            )
            train_test_split_seed = gr.Number(
                value=42,
                label="Train test split seed",
                interactive=True,
                precision=0,
                minimum=0,
            )

        # The three sample boxes start hidden; they are outputs of `run_all`,
        # which returns Textboxes created with visible=True.
        with gr.Accordion(
            label="Data sample (open to see more when fine-tuning is done)", open=False
        ):
            gr.Markdown("### Prompt")
            sample_prompt = gr.Textbox(
                lines=10,
                show_label=False,
                container=False,
                visible=False,
            )

            gr.Markdown("### Chosen")
            sample_chosen = gr.Textbox(
                lines=10,
                show_label=False,
                container=False,
                visible=False,
            )

            gr.Markdown("### Rejected")
            sample_rejected = gr.Textbox(
                lines=10,
                show_label=False,
                container=False,
                visible=False,
            )

        gr.HTML("<br>")

        # GET THE MODEL --------------------------------------------------------
        gr.Markdown("# Get the model")
        with gr.Row():
            model_name = gr.Dropdown(
                choices=model_name_list,
                value="unsloth/zephyr-sft-bnb-4bit",
                label="Model name",
                interactive=True,
            )
            max_seq_length = gr.Number(
                value=256,
                label="Max sequence length (input)",
                interactive=False,
            )
            # NOTE(review): this field is pre-filled with the value returned by
            # get_token() and the app is launched with share=True below.
            # Confirm the token cannot be read by people opening the shared link.
            token = gr.Textbox(
                value=get_token(),
                label="Hugging Face token",
                interactive=False,
                type="password",
            )
            load_in_4bit = gr.Checkbox(
                value=True,
                label="Load in 4bit",
                interactive=False,
            )

        gr.Markdown("---")

        # GET THE PEFT MODEL ---------------------------------------------------
        with gr.Accordion(
            label="PEFT config (open to view and edit LoRA adapters)",
            open=False,
        ):
            with gr.Row():
                r = gr.Dropdown(
                    choices=exp2_list(3, 6),
                    value=2**3,
                    label="Lora rank",
                    scale=1,
                    interactive=True,
                )
                lora_alpha = gr.Dropdown(
                    choices=exp2_list(3, 6),
                    value=2**3,
                    label="Lora alpha",
                    scale=1,
                    interactive=True,
                )
                peft_model_random_state = gr.Number(
                    value=42,
                    label="PEFT model seed",
                    scale=1,
                    interactive=True,
                    precision=0,
                    minimum=0,
                )
                target_modules = gr.Dropdown(
                    choices=target_module_list,
                    value=target_module_list,
                    multiselect=True,
                    label="Target modules to be trained (choose one or many)",
                    scale=5,
                    interactive=True,
                )

            # The components of this row are neither visible nor editable, but
            # they are still listed in `inputs`, so their values reach `run_all`.
            with gr.Row():
                loftq_config = gr.Textbox(
                    value=None,
                    label="LoRA-Fine-Tuning-Aware Quantization",
                    interactive=False,
                    visible=False,
                )
                # NOTE(review): True here, while get_peft_model defaults
                # use_rslora to False; confirm which one is intended.
                use_rslora = gr.Checkbox(
                    value=True,
                    label="Rank-stabilized LoRA",
                    interactive=False,
                    visible=False,
                )
                lora_dropout = gr.Number(
                    value=0,
                    label="LoRA dropout",
                    interactive=False,
                    visible=False,
                )
                bias = gr.Textbox(
                    value="none",
                    label="Bias",
                    interactive=False,
                    visible=False,
                )
                use_gradient_checkpointing = gr.Dropdown(
                    choices=use_gradient_checkpointing_list,
                    value="unsloth",
                    label="Use gradient checkpoint",
                    interactive=False,
                    visible=False,
                )

        gr.Markdown("---")

        # GET THE DPO TRAINER --------------------------------------------------
        with gr.Accordion(
            label="DPO trainer (open to view and edit DPO params)",
            open=False,
        ):
            with gr.Row():
                per_device_train_batch_size = gr.Dropdown(
                    choices=exp2_list(0, 2),
                    value=2**2,
                    label="Per device train batch size",
                    interactive=True,
                )
                num_train_epochs = gr.Number(
                    value=3,
                    label="Number of training epochs",
                    interactive=True,
                    precision=0,
                    minimum=1,
                )
                learning_rate = gr.Number(
                    value=5e-6,
                    label="Learning rate",
                    interactive=True,
                    minimum=0,
                )
                logging_steps = gr.Number(
                    value=50,
                    label="Logging steps",
                    interactive=True,
                    minimum=0,
                )
                optim = gr.Dropdown(
                    choices=optim_list,
                    value=optim_list[0],
                    label="Optimizer",
                    interactive=True,
                )
            with gr.Row():
                dpo_trainer_seed = gr.Number(
                    value=42,
                    label="DPO trainer seed",
                    scale=1,
                    interactive=True,
                    precision=0,
                    minimum=0,
                )
                beta = gr.Number(
                    value=0.1,
                    label="Beta",
                    scale=1,
                    interactive=False,
                )
                max_length = gr.Number(
                    value=256,
                    label="Max length",
                    scale=1,
                    interactive=False,
                )
                max_prompt_length = gr.Number(
                    value=128,
                    label="Max prompt length",
                    scale=1,
                    interactive=False,
                )
                output_dir = gr.Textbox(
                    label="HF saving repo",
                    value="fine-tuned-model",
                    scale=2,
                    interactive=False,
                )
                with gr.Column(scale=2):
                    push_to_hub = gr.Checkbox(
                        value=True,
                        label="Push final model to hub",
                        interactive=False,
                    )
                    # Display only: fp16 and bf16 are not part of `inputs`;
                    # get_dpo_trainer derives both flags itself from
                    # torch.cuda.is_bf16_supported().
                    with gr.Row():
                        fp16 = gr.Checkbox(
                            value=not torch.cuda.is_bf16_supported(),
                            label="fp16",
                            interactive=False,
                        )
                        bf16 = gr.Checkbox(
                            value=torch.cuda.is_bf16_supported(),
                            label="bf16",
                            interactive=False,
                        )

            # Hidden, read-only values that are still passed to `run_all`
            with gr.Row():
                gradient_accumulation_steps = gr.Number(
                    value=4,
                    label="Gradient accumulation steps",
                    interactive=False,
                    visible=False,
                )
                warmup_ratio = gr.Number(
                    value=0.1,
                    label="Warmup ratio",
                    interactive=False,
                    visible=False,
                )
                weight_decay = gr.Number(
                    value=0.0,
                    label="Weight decay",
                    interactive=False,
                    visible=False,
                )
                lr_scheduler_type = gr.Textbox(
                    value="linear",
                    label="LR scheduler type",
                    interactive=False,
                    visible=False,
                )

        gr.HTML("<br>")

        # GET THE TRAINING METRICS ---------------------------------------------
        # Created with render=False so it can be listed in `outputs` here and
        # rendered later, inside the "Training metrics" tab.
        training_metrics = gr.Dataframe(
            value=pd.DataFrame(),
            visible=False,
            render=False,
        )

        # RUN ALL --------------------------------------------------------------
        # Listed in the same order as the parameters of `run_all`
        inputs = [
            # get_dataset args
            dataset_name,
            max_samples,
            test_size,
            train_test_split_seed,
            # get_model args
            model_name,
            token,
            max_seq_length,
            load_in_4bit,
            # get_peft_model args
            r,
            lora_alpha,
            peft_model_random_state,
            target_modules,
            loftq_config,
            use_rslora,
            lora_dropout,
            bias,
            use_gradient_checkpointing,
            # get_dpo_trainer args
            per_device_train_batch_size,
            num_train_epochs,
            learning_rate,
            logging_steps,
            optim,
            dpo_trainer_seed,
            beta,
            max_length,
            max_prompt_length,
            output_dir,
            push_to_hub,
            gradient_accumulation_steps,
            warmup_ratio,
            weight_decay,
            lr_scheduler_type,
        ]
        # Same order as the tuple returned by `run_all`
        outputs = [
            sample_prompt,
            sample_chosen,
            sample_rejected,
            training_metrics,
        ]

        gr.Markdown("# Fine-tune your model on your data")
        run_all_button = gr.Button(value="Run DPO fine-tuning")
        run_all_button.click(fn=run_all, inputs=inputs, outputs=outputs)
        # NOTE(review): this statement is part of the layout code, so it runs
        # once while the interface is being built, not after each click.
        gc.collect()

    with gr.Tab("Training metrics"):
        training_metrics.render()

# Start the app. NOTE(review): share=True presumably exposes the app through
# a public link; see the token note above.
demo.launch(share=True, debug=True)