
Commit

Support precompute ref probs
AjayP13 committed Dec 24, 2023
1 parent cb48842 commit 8859ee1
Showing 5 changed files with 94 additions and 8 deletions.
5 changes: 5 additions & 0 deletions src/datasets/datasets.py
@@ -6,6 +6,7 @@

from datasets import Dataset, IterableDataset
from datasets.features.features import Features, Value
from datasets.fingerprint import Hasher

from ..datasets.utils import get_column_names
from ..pickling import unpickle_transform
@@ -82,6 +83,10 @@ def __getitem__(self, key: int | slice | str | Iterable[int]) -> Any:
else:
return self.dataset[key] # type:ignore[attr-defined]

@property
def fingerprint(self) -> Any:
return Hasher.hash((self.step.fingerprint, self.column_names))

def head(self, n=5, shuffle=False, seed=None, buffer_size=1000) -> DataFrame:
if isinstance(self.dataset, Dataset): # type:ignore[attr-defined]
iterable_dataset = (
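For context, a minimal sketch of what the new fingerprint property computes. The toy step below is hypothetical; in the real code the dataset already holds a reference to the step that produced it:

    from datasets.fingerprint import Hasher

    class ToyStep:
        # Hypothetical stand-in for the producing step; only its fingerprint matters here.
        fingerprint = "abc123"

    step = ToyStep()
    column_names = ["prompt", "chosen", "rejected"]

    # Hash the producing step's fingerprint together with the selected column names,
    # so the same step sliced into different columns yields different fingerprints.
    print(Hasher.hash((step.fingerprint, column_names)))
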
3 changes: 2 additions & 1 deletion src/steps/step.py
@@ -39,6 +39,7 @@
from ..utils.arg_utils import DEFAULT, Default
from ..utils.background_utils import run_in_background_process_no_block
from ..utils.collection_utils import uniq_str
from ..utils.fingerprint_utils import stable_fingerprint
from ..utils.fs_utils import move_dir, safe_fn
from ..utils.hf_hub_utils import get_readme_contents, hf_hub_login, prepare_to_publish
from ..utils.time_utils import progress_eta
@@ -1575,7 +1576,7 @@ def filter_arg_name(arg_name: str) -> bool:
def map_value(val: Any) -> str:
if isinstance(val, _Cachable):
return Hasher.hash((val.version, val._cache_name))
return Hasher.hash(val)
return stable_fingerprint(val)

return Hasher.hash(
[
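A hedged illustration of why map_value now falls back to stable_fingerprint instead of Hasher.hash: hashing a callable directly can pick up interpreter-specific memory addresses, while stable_fingerprint (see the fingerprint_utils.py hunk below) hashes the callable's source with addresses normalized. The helper and function names here are illustrative only:

    import re

    from dill.source import getsource
    from datasets.fingerprint import Hasher

    def hash_callable_stably(fn):
        # Replace "... at 0x7f3a..." addresses with a constant so the hash is
        # identical across interpreter runs, mirroring stable_fingerprint.
        return Hasher.hash(re.sub(r" at 0x[0-9a-f]+", " at 0x0", getsource(fn)))

    def format_prompt(row):
        return row["prompt"].strip()

    print(hash_callable_stably(format_prompt))
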
16 changes: 14 additions & 2 deletions src/trainers/_train_hf_base.py
@@ -552,7 +552,7 @@ def prepare_for_reward_pairs(row):
else 120,
).output.dataset
elif dpo:
from ._dpo_helper import DPODataCollatorWithPadding
from ._dpo_helper import DPODataCollatorWithPadding # type:ignore[attr-defined]

# Get data collator
data_collator = DPODataCollatorWithPadding(
@@ -1653,6 +1653,11 @@ def compute_fingerprint(
self,
**kwargs,
) -> str:
def filter_kwargs(arg_name: str) -> bool:
return arg_name not in [
"precompute_ref_log_probs",
]

column_fingerprints = {}
for kwarg in sorted(kwargs.keys()):
if isinstance(
@@ -1663,6 +1668,7 @@ def compute_fingerprint(
column.step.fingerprint,
column.column_names,
)

fingerprint = Hasher.hash(
[
str(type(self).__name__),
@@ -1678,7 +1684,13 @@
self.quantization_config,
self.peft_config,
column_fingerprints,
stable_fingerprint(kwargs),
stable_fingerprint(
{
kwarg: val
for kwarg, val in kwargs.items()
if filter_kwargs(kwarg)
}
),
]
)
self.fingerprint = fingerprint
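The effect of the new filter_kwargs guard, sketched in isolation with assumed names: precompute_ref_log_probs only changes how the reference log probs are obtained, not what the trained model learns, so it is excluded from the fingerprint and previously cached results remain valid whichever way the flag is set.

    from datasets.fingerprint import Hasher

    EXCLUDED = ("precompute_ref_log_probs",)

    def fingerprint_kwargs(kwargs):
        # Only hash kwargs that can change the training outcome.
        return Hasher.hash(
            {k: v for k, v in sorted(kwargs.items()) if k not in EXCLUDED}
        )

    a = fingerprint_kwargs({"dpo_beta": 0.1, "precompute_ref_log_probs": True})
    b = fingerprint_kwargs({"dpo_beta": 0.1, "precompute_ref_log_probs": False})
    assert a == b  # toggling the flag does not invalidate the cache
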
75 changes: 70 additions & 5 deletions src/trainers/train_hf_dpo.py
@@ -5,6 +5,8 @@
import torch

from ..datasets import OutputDatasetColumn, OutputIterableDatasetColumn
from ..steps import Step
from ..steps.step_operations import _INTERNAL_STEP_OPERATION_KEY
from ..utils.arg_utils import AUTO, Default
from ..utils.distributed_utils import is_distributed, validate_device_and_device_map
from ..utils.import_utils import ignore_transformers_warnings, ignore_trl_warnings
@@ -26,6 +28,20 @@
from transformers.utils.quantization_config import QuantizationConfigMixin


class _PreComputeRefLogProbs(Step):
def setup(self):
self.register_arg("pre_compute_func", help="The pre-compute function.")
self.register_arg(
"dataset_fingerprint", help="The dataset (for fingerprinting purposes)."
)

def run(self):
return self.args["pre_compute_func"]()


setattr(_PreComputeRefLogProbs, _INTERNAL_STEP_OPERATION_KEY, True)


class TrainHFDPO(TrainHFFineTune):
def __init__(
self,
@@ -104,6 +120,9 @@ def _train( # type:ignore[override]
dpo_beta: float = 0.1,
loss_type: str = "sigmoid",
disable_dropout: bool = True,
# TODO: change to True once bug is fixed:
# https://github.com/huggingface/trl/issues/1139
precompute_ref_log_probs: bool = False,
seed: int = 42,
**kwargs,
):
@@ -151,8 +170,12 @@ def _train( # type:ignore[override]
)

# We have already tokenized the dataset, so don't let DPOTrainer try to tokenize.
train_dataset.map = lambda *args, **kwargs: train_dataset
validation_dataset.map = lambda *args, **kwargs: validation_dataset
train_dataset.map = ( # type:ignore[method-assign,union-attr]
lambda *args, **kwargs: train_dataset
)
validation_dataset.map = ( # type:ignore[method-assign,union-attr]
lambda *args, **kwargs: validation_dataset
)

# Prepare compute metrics
compute_metrics = kwargs.pop("compute_metrics", None)
@@ -191,9 +214,13 @@ def _train( # type:ignore[override]
# Prepare model and reference model
self.seed = seed
model = self._create_model()
if self.peft_config:
# DPOTrainer will automatically use the model with the adapters disabled
# as the reference model
if self.peft_config or precompute_ref_log_probs:
# DPOTrainer will automatically use the PEFT model with the adapters disabled
# as the reference model.
# OR...
# If we are pre-computing the ref log probs, they will be computed at the
beginning of training before the model weights are updated, so we don't
# need to keep a separate reference model at all.
ref_model = None
else:
ref_model = self._create_model(
@@ -333,12 +360,50 @@ def _train( # type:ignore[override]
beta=dpo_beta,
loss_type=loss_type,
disable_dropout=disable_dropout,
precompute_ref_log_probs=precompute_ref_log_probs,
generate_during_eval=False,
)
assert trainer.use_dpo_data_collator is False
trainer.use_dpo_data_collator = True
trainer.remove_callback(PrinterCallback)

# Pre-compute ref_log_probs
if precompute_ref_log_probs:

def pre_compute_train():
trainer.get_train_dataloader()
return trainer.train_dataset

trainer.train_dataset = _PreComputeRefLogProbs(
"Pre-Compute Reference Log Probs on Train Dataset",
args={
"pre_compute_func": pre_compute_train,
"dataset_fingerprint": [
c.fingerprint
for c in [train_prompts, train_chosen, train_rejected]
],
},
).output.dataset

def pre_compute_eval():
trainer.get_eval_dataloader()
return trainer.eval_dataset

trainer.eval_dataset = _PreComputeRefLogProbs(
"Pre-Compute Reference Log Probs on Validation Dataset",
args={
"pre_compute_func": pre_compute_eval,
"dataset_fingerprint": [
c.fingerprint
for c in [
validation_prompts,
validation_chosen,
validation_rejected,
]
],
},
).output.dataset

# Start the trainer
_start_hf_trainer(self, trainer)

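A rough, runnable sketch of the wrapper pattern used above. The fake trainer stands in for trl's DPOTrainer, whose get_train_dataloader() performs the one-time reference-log-prob pass when precompute_ref_log_probs=True; wrapping that call in _PreComputeRefLogProbs lets DataDreamer cache the augmented dataset keyed by the input column fingerprints. All names in the sketch are hypothetical:

    class FakeDPOTrainer:
        # Stand-in used only to show the shape of the zero-argument pre-compute closure.
        def __init__(self, dataset):
            self.train_dataset = dataset

        def get_train_dataloader(self):
            # In the real DPOTrainer this triggers a single pass of the reference
            # model over the dataset, storing its log probs alongside each row.
            self.train_dataset = [dict(row, ref_logp=0.0) for row in self.train_dataset]

    def pre_compute_train(trainer):
        trainer.get_train_dataloader()
        return trainer.train_dataset

    trainer = FakeDPOTrainer([{"prompt": "hi", "chosen": "a", "rejected": "b"}])
    print(pre_compute_train(trainer))
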
3 changes: 3 additions & 0 deletions src/utils/fingerprint_utils.py
@@ -5,6 +5,7 @@
import dill
from dill.source import getsource

from datasets import Dataset
from datasets.fingerprint import Hasher

from .. import DataDreamer
@@ -89,5 +90,7 @@ def stable_fingerprint(value: Any) -> str:
)
elif type(value) is type or callable(value):
return Hasher.hash(re.sub(r" at 0x[0-9a-f]+", " at 0x0", getsource(value)))
elif isinstance(value, Dataset): # pragma: no cover
return value._fingerprint # type:ignore[attr-defined]
else:
return Hasher.hash(value)
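
The new Dataset branch reuses the fingerprint that Hugging Face datasets already maintains for every Dataset, avoiding a re-hash of the full table; _fingerprint is a private attribute, hence the type:ignore. A quick illustration:

    from datasets import Dataset

    ds = Dataset.from_dict({"prompt": ["hi", "bye"]})
    print(ds._fingerprint)  # stable identifier for this exact data
    print(ds.map(lambda row: row)._fingerprint)  # changes after any transform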
