<img src="https://cdn.comet.ml/img/notebook_logo.png">

[Comet](https://www.comet.com/site/products/ml-experiment-tracking/?utm_campaign=ray_train&utm_medium=colab) is an MLOps platform that helps data scientists and teams build better models faster. It provides tooling to track, explain, manage, and monitor your models in a single place. It works with Jupyter notebooks and scripts, and it is free to get started.

[TRL](https://github.com/huggingface/trl) is a library for post-training foundation models with techniques such as Supervised Fine-Tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO).

Instrument your runs with Comet to manage experiments, create dataset versions, and track hyperparameters, so that results are easier to reproduce and share.

[Find more information about our integration with TRL](https://www.comet.ml/docs/v2/integrations/ml-frameworks/trl/)

Get a preview of what's to come. Check out a completed experiment created from this notebook [here](TODO).

This example is based on the [following Ray Train Lightning example](https://docs.ray.io/en/latest/train/getting-started-pytorch-lightning.html).

# Install Dependencies

In [None]:
%pip install "comet_ml>=3.47.1" "trl>=0.13.0"

# Initialize Comet

In [None]:
import comet_ml

comet_ml.login()
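
Besides logging in interactively, Comet reads its configuration from environment variables. A minimal sketch, assuming you want the run grouped under a named project; the project name below is a hypothetical placeholder, not something this notebook defines:

```python
import os

# COMET_PROJECT_NAME is a standard Comet configuration variable.
# The value is a hypothetical placeholder; choose your own project name.
os.environ.setdefault("COMET_PROJECT_NAME", "trl-dpo-example")

# For unattended runs, COMET_API_KEY can be exported in the shell
# before launching, so that no interactive prompt is needed.
print(os.environ["COMET_PROJECT_NAME"])
```

Set the variable before the trainer is created so the Comet integration can pick it up.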

# Import Dependencies

In [None]:
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl import DPOConfig, DPOTrainer
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

# Load your dataset

In [None]:
dataset = load_dataset("trl-lib/ultrafeedback_binarized")
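
DPO trains on preference pairs: a preferred (`chosen`) and a dispreferred (`rejected`) response to the same prompt. A minimal sketch of one such record, assuming the conversational format in which `chosen` and `rejected` are lists of role/content messages. The record is a made-up toy example, not a row from the dataset, and `final_reply` is a hypothetical helper:

```python
# Toy preference record, invented for illustration only.
record = {
    "chosen": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "2 + 2 equals 4."},
    ],
    "rejected": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "2 + 2 equals 5."},
    ],
}


def final_reply(messages):
    """Return the content of the last assistant message in a conversation."""
    for message in reversed(messages):
        if message["role"] == "assistant":
            return message["content"]
    raise ValueError("conversation has no assistant message")


print("chosen:  ", final_reply(record["chosen"]))
print("rejected:", final_reply(record["rejected"]))
```

To check the real schema, print `dataset["train"][0]` after loading.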

# Train the model
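
For intuition about what the trainer optimizes, here is a minimal sketch of the standard sigmoid DPO loss for a single preference pair, in plain Python. It is not TRL's implementation: `dpo_loss` and its argument names are hypothetical, the inputs are assumed to be summed log-probabilities of each response under the policy and the reference model, and `beta=0.1` is an assumed temperature.

```python
import math


def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Sigmoid DPO loss for one pair: -log(sigmoid(beta * margin))."""
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    margin = beta * (chosen_logratio - rejected_logratio)
    # -log(sigmoid(x)) == log(1 + exp(-x))
    return math.log1p(math.exp(-margin))


# Policy identical to the reference: the margin is 0 and the loss is log(2).
print(dpo_loss(-10.0, -12.0, -10.0, -12.0))
# Policy favors the chosen response more than the reference does: lower loss.
print(dpo_loss(-8.0, -14.0, -10.0, -12.0))
```

The loss falls as the policy raises the likelihood of the chosen response relative to the rejected one, measured against the reference model. This is why the cell below loads two copies of the same checkpoint.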

In [None]:
# Policy model that DPO will optimize.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-0.5B-Instruct",
)
# Reference model from the same checkpoint; the policy is compared against it.
ref_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-0.5B-Instruct",
)

tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen2-0.5B-Instruct",
)
# Fall back to sensible defaults if the tokenizer lacks a pad token or chat template.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
if tokenizer.chat_template is None:
    tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

training_args = DPOConfig(
    output_dir="/tmp",
    learning_rate=5.0e-7,
    max_steps=10,  # short demonstration run
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    logging_steps=1,  # log metrics at every step
    eval_strategy="steps",
    eval_steps=5,  # evaluate every 5 steps
    report_to=["comet_ml"],  # send metrics and parameters to Comet
)
trainer = DPOTrainer(
    model,
    ref_model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    processing_class=tokenizer,
)
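
As a quick sanity check on how much data this short run sees, the effective batch size and the number of training pairs consumed follow from the settings above. A small sketch, assuming a single device (`num_devices` is an assumption, not something the notebook sets):

```python
per_device_train_batch_size = 1
gradient_accumulation_steps = 8
max_steps = 10
num_devices = 1  # assumption: single-GPU or CPU run

# Preference pairs contributing to each optimizer step.
effective_batch_size = (
    per_device_train_batch_size * gradient_accumulation_steps * num_devices
)
# Preference pairs consumed over the whole run.
pairs_seen = effective_batch_size * max_steps

print(effective_batch_size, pairs_seen)  # 8 80
```

With `eval_steps=5` and `max_steps=10`, evaluation runs at steps 5 and 10. Each pass covers the whole `test` split, which may dominate the runtime of such a short run.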

In [None]:
trainer.train()

In [None]:
comet_ml.end()