# Direct Preference Optimization: Fine-Tuning an LLM Using Preference Data
* Author: mark@datarobot.com
* Date: 2026-02-10

## Summary

This notebook shows how to use preference data to update a model. It takes a dataset of (query, good response, bad response) triples and fine-tunes a model on it with DPO, all in a single session:

1. Download the preference data from DataRobot
2. Train a model using Direct Preference Optimization (DPO)
3. Upload the new model weights to DataRobot, ready to be registered and then deployed.



## Setup

### Import libraries

In [None]:
from datasets import load_dataset
from trl import DPOConfig, DPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
import datarobot as dr
import torch

from accelerate import notebook_launcher


### Bind variables

In [None]:
# These variables can also be fetched from a secret store or config file
DATAROBOT_ENDPOINT = "https://app.eu.datarobot.com/api/v2"
# The URL varies with your hosting option; the example above is for the DataRobot EU Managed AI Cloud

DATAROBOT_API_TOKEN = "<INSERT YOUR DataRobot API Token>"
# To find your API token, click the avatar icon and then </> Developer Tools
 

### Connect to DataRobot

You can read more about different options for [connecting to DataRobot from the client](https://docs.datarobot.com/en/docs/api/api-quickstart/api-qs.html).
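As one illustration of keeping credentials out of the notebook, the sketch below reads the endpoint and token from environment variables. The helper name `load_datarobot_credentials` is hypothetical, and the fallback endpoint simply mirrors the EU example used in this notebook; adapt both to your setup.

```python
import os


def load_datarobot_credentials(env=os.environ):
    """Return (endpoint, token) read from a mapping of environment variables.

    Hypothetical helper for illustration; it is not part of the DataRobot client.
    """
    # Fall back to the EU endpoint used as the example in this notebook.
    endpoint = env.get("DATAROBOT_ENDPOINT", "https://app.eu.datarobot.com/api/v2")
    token = env.get("DATAROBOT_API_TOKEN")
    if not token:
        raise RuntimeError("DATAROBOT_API_TOKEN is not set")
    return endpoint, token


# Example usage in the notebook (requires the variables to be set):
# endpoint, token = load_datarobot_credentials()
# dr.Client(endpoint=endpoint, token=token)
```

Passing the mapping as an argument keeps the helper easy to exercise with a plain dictionary before pointing it at the real environment.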

In [None]:
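# Pass the variables bound above explicitly so the client does not rely on external configuration
dr.Client(endpoint=DATAROBOT_ENDPOINT, token=DATAROBOT_API_TOKEN)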

## Download Data

This section downloads a pre-created dataset from the DataRobot registry.

In [None]:
def download_registry_file(dataset_id, local_path):
    # Retrieve the dataset object from the registry
    dataset = dr.Dataset.get(dataset_id)

    # Download the file
    print(f"Downloading {dataset.name}...")
    dataset.get_file(local_path)
    print(f"File saved to: {local_path}")


In [None]:
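DATASET_ID = '<DATASET_ID>'
DATASET_NAME = 'preference_training_dataset.csv'
# Download the dataset so the training step can load it from disk
download_registry_file(DATASET_ID, DATASET_NAME)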
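
Before training, it can help to confirm that the downloaded file has the shape the training step expects. The sketch below is a minimal check, assuming the CSV uses the column names `prompt`, `chosen`, and `rejected` noted in the training code; the helper name `check_preference_csv` is hypothetical.

```python
import csv

# Column names assumed from the "prompt, chosen, rejected" format used for training.
REQUIRED_COLUMNS = ("prompt", "chosen", "rejected")


def check_preference_csv(path):
    """Return the number of data rows; raise ValueError on missing columns or empty cells."""
    with open(path, newline="", encoding="utf-8") as handle:
        reader = csv.DictReader(handle)
        header = reader.fieldnames or []
        missing = [name for name in REQUIRED_COLUMNS if name not in header]
        if missing:
            raise ValueError(f"Missing columns: {missing}")
        count = 0
        for row_number, row in enumerate(reader, start=1):
            for name in REQUIRED_COLUMNS:
                # DictReader yields None for cells missing from a short row.
                if not (row[name] or "").strip():
                    raise ValueError(f"Empty '{name}' in data row {row_number}")
            count += 1
    return count


# Example usage in the notebook:
# print(check_preference_csv(DATASET_NAME))
```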

## Fine-Tuning with DPO

This section uses `trl` and Hugging Face `accelerate` to perform [Direct Preference Optimization](https://arxiv.org/abs/2305.18290).

This example is designed to run on 4 NVIDIA A10 GPUs.
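
For intuition, the per-pair DPO objective from the linked paper is `-log(sigmoid(beta * ((log pi(chosen) - log pi_ref(chosen)) - (log pi(rejected) - log pi_ref(rejected)))))`, where `pi` is the model being trained and `pi_ref` is the frozen reference model. The sketch below evaluates that expression in plain Python on scalar log-probabilities. The function name and the `beta` value are illustrative assumptions, and this is not how `trl` computes the loss internally.

```python
import math


def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Per-pair DPO loss from sequence log-probabilities (illustrative sketch)."""
    # Implicit rewards: how far the policy has moved from the reference on each response.
    chosen_reward = beta * (policy_chosen_logp - ref_chosen_logp)
    rejected_reward = beta * (policy_rejected_logp - ref_rejected_logp)
    margin = chosen_reward - rejected_reward
    # -log(sigmoid(margin)) equals softplus(-margin); this form avoids overflow.
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))


# A policy identical to the reference has zero margin, so the loss is log(2).
print(dpo_loss(-2.0, -2.0, -2.0, -2.0))
# Raising the chosen log-probability and lowering the rejected one reduces the loss.
print(dpo_loss(-1.0, -3.0, -2.0, -2.0))
```

The loss only falls when the policy widens the gap between chosen and rejected responses relative to the reference, which is why no separate reward model is needed.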

In [None]:
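MODEL_ID = 'Qwen/Qwen2-0.5B-Instruct'
TMP_DIR = '/tmp/qwen2-0.5b-dpo'
OUTPUT_DIR = '/home/notebooks/storage/qwen2-0.5b-dpo'
# TensorBoard event files; this location is an assumption, adjust as needed
LOGGING_DIR = f'{OUTPUT_DIR}/logs'
SHARD_WRAP_CLASS = "Qwen2DecoderLayer"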

In [None]:
def dpo_train():
    # 1. Load Dataset (Format: prompt, chosen, rejected)
    # split="train" returns a Dataset rather than a DatasetDict keyed by split name
    dataset = load_dataset("csv", data_files=DATASET_NAME, split="train")

    # 2. Load Model & Tokenizer
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, 
        torch_dtype=torch.bfloat16,
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    tokenizer.pad_token = tokenizer.eos_token

    # 3. DPO Configuration
    training_args = DPOConfig(
        output_dir=OUTPUT_DIR,
        per_device_train_batch_size=4, # Increase this if VRAM allows
        gradient_accumulation_steps=4,
        learning_rate=5e-7,
        lr_scheduler_type="cosine",
        logging_steps=1,
        max_steps=500,
        bf16=True,
        fsdp="full_shard auto_wrap",
        fsdp_config={
            "transformer_layer_cls_to_wrap": SHARD_WRAP_CLASS,
            "fsdp_state_dict_type": "FULL_STATE_DICT",
            "fsdp_offload_params": True,               # Move gathered weights to CPU
        },
        gradient_checkpointing=True,
        remove_unused_columns=False,
        logging_dir=LOGGING_DIR,          # Where TensorBoard events will be saved
        report_to=["tensorboard"],       # Enables TensorBoard logging
        logging_first_step=True,
        # Turn this OFF to avoid the ValueError during the run
        save_only_model=False,
        save_strategy="no",
    
    )


    # 4. Initialize Trainer
    trainer = DPOTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
    )

    trainer.train()


    # 5. Wait for all processes to finish training
    trainer.accelerator.wait_for_everyone()

    # 6. With FSDP, gather the sharded weights into one full state dict for the final save
    if trainer.is_fsdp_enabled:
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

    # 7. Save the model. Call this on every rank: all ranks take part in
    #    gathering the weights, and only the main process writes the files.
    trainer.save_model(OUTPUT_DIR)

    # 8. Ensure the tokenizer is saved alongside the weights, on rank 0 only
    if trainer.accelerator.is_main_process:
        tokenizer.save_pretrained(OUTPUT_DIR)

In [None]:

# One process per GPU; this example is designed for 4 A10s
notebook_launcher(dpo_train, num_processes=4)
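
To relate `max_steps` to the size of the preference dataset, the arithmetic below works out how many preference pairs each optimizer step consumes. It assumes one process per GPU on the 4 A10s mentioned above; the function name is illustrative.

```python
def effective_batch_size(per_device_batch_size, grad_accum_steps, num_processes):
    """Number of preference pairs consumed per optimizer step."""
    return per_device_batch_size * grad_accum_steps * num_processes


# Values taken from the DPOConfig above, with 4 processes assumed (one per GPU).
pairs_per_step = effective_batch_size(
    per_device_batch_size=4, grad_accum_steps=4, num_processes=4
)
total_pairs = pairs_per_step * 500  # max_steps=500

print(pairs_per_step)  # 64
print(total_pairs)     # 32000
```

If the dataset holds far fewer than 32,000 pairs, the run will pass over it many times, so it may be worth lowering `max_steps` accordingly.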

## Upload to DataRobot Workshop

In [None]:
RUNTIME_ID = '662d6a54ef58f64c5a07d122'
CUSTOM_MODEL_NAME = 'DPO_Trained_Model'

In [None]:
def upload_to_custom_workshop(model_name, local_folder_path, runtime_id):
    # 1. Create the Custom Model shell
    custom_model = dr.CustomInferenceModel.create(
        name=model_name,
        target_type=dr.TARGET_TYPE.TEXT_GENERATION,  # Other target types include BINARY, REGRESSION, MULTICLASS
        target_name="promptText",
    )

    # 2. Upload files and create a version
    # 'local_folder_path' should contain your model.py, requirements.txt, etc.
    model_version = dr.CustomModelVersion.create_clean(
        custom_model_id=custom_model.id,
        base_environment_id=runtime_id,
        folder_path=local_folder_path
    )

    print(f"Model Created: {custom_model.id}")
    print(f"Version Created: {model_version.id}")
    return model_version

In [None]:
upload_to_custom_workshop(CUSTOM_MODEL_NAME, OUTPUT_DIR, RUNTIME_ID)
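
Since the upload sends the whole folder, a quick pre-flight check can catch a missing file before a version is created. The sketch below is a generic helper; its name is hypothetical, and the file names in the usage comment are just the ones mentioned in the upload function's comment, not a confirmed list of what DataRobot requires.

```python
from pathlib import Path


def missing_files(folder, required):
    """Return the names from `required` that are not regular files in `folder`."""
    root = Path(folder)
    return [name for name in required if not (root / name).is_file()]


# Example usage in the notebook (file names are assumptions):
# missing = missing_files(OUTPUT_DIR, ["model.py", "requirements.txt"])
# if missing:
#     raise FileNotFoundError(f"Add these files before uploading: {missing}")
```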
