diff --git a/ludwig/constants.py b/ludwig/constants.py index b72791b536f..33a52d84123 100644 --- a/ludwig/constants.py +++ b/ludwig/constants.py @@ -72,6 +72,7 @@ MEAN_ABSOLUTE_PERCENTAGE_ERROR = "mean_absolute_percentage_error" HUBER = "huber" CORN = "corn" +REWARD = "reward_model_loss" R2 = "r2" EDIT_DISTANCE = "edit_distance" PERPLEXITY = "perplexity" @@ -269,6 +270,7 @@ MODEL_ECD = "ecd" MODEL_GBM = "gbm" MODEL_LLM = "llm" +MODEL_RWD = "rwd" DASK_MODULE_NAME = "dask.dataframe" LUDWIG_VERSION = "ludwig_version" diff --git a/ludwig/data/preprocessing.py b/ludwig/data/preprocessing.py index 6cac5f23f77..7bde94e5898 100644 --- a/ludwig/data/preprocessing.py +++ b/ludwig/data/preprocessing.py @@ -1322,6 +1322,67 @@ def build_dataset( if reshape is not None: proc_cols[proc_column] = backend.df_engine.map_objects(proc_cols[proc_column], lambda x: x.reshape(-1)) + # If training a reward model, prepare the processed columns + if mode == "training" and global_preprocessing_parameters["reward_dataset"] is not None: + reward_parameter_names = [ + "id_column", + "outcome_column", + "chosen_value", + "rejected_value", + "transcript_column", + ] + if not all( + param_name in global_preprocessing_parameters["reward_dataset"] for param_name in reward_parameter_names + ): + raise ValueError(f"Invalid reward training preprocessing parameters, expect {reward_parameter_names}.") + + # Obtain column names and other values + id_column = global_preprocessing_parameters["reward_dataset"]["id_column"] + outcome_column = global_preprocessing_parameters["reward_dataset"]["outcome_column"] + chosen_value = global_preprocessing_parameters["reward_dataset"]["chosen_value"] + rejected_value = global_preprocessing_parameters["reward_dataset"]["rejected_value"] + transcript_column = global_preprocessing_parameters["reward_dataset"]["transcript_column"] + + # Validate the input configuration + if not all( + [ + len(config["input_features"]) == 1, + len(config["output_features"]) == 1, + config["input_features"][0]["name"] == transcript_column, + config["input_features"][0]["type"] == "text", + config["output_features"][0]["name"] == id_column, + config["output_features"][0]["type"] == "number", + ] + ): + raise ValueError(f"Invalid reward model training configuration, received {config}.") + + # Validate the input dataframe's columns + dataset_columns_expected = sorted([id_column, outcome_column, transcript_column]) + dataset_columns_actual = sorted(dataset_df.columns) + if "split" in dataset_columns_actual: + dataset_columns_actual.remove("split") + if dataset_columns_actual != dataset_columns_expected: + raise ValueError( + f"Invalid reward training input dataset, expect columns {dataset_columns_expected}, " + f"got columns {dataset_columns_actual}." + ) + + # Validate the processed dataset columns + id_column = config["output_features"][0]["proc_column"] + transcript_column = config["input_features"][0]["proc_column"] + proc_columns_expected = sorted([id_column, transcript_column]) + proc_columns_actual = sorted(proc_cols.keys()) + if "split" in proc_columns_actual: + proc_columns_actual.remove("split") + if proc_columns_actual != proc_columns_expected: + raise ValueError( + f"Invalid reward training processed dataset, expect columns {proc_columns_expected}, " + f"got columns {proc_columns_actual}." 
+ ) + + # Add the outcome column to processed columns + proc_cols[outcome_column] = dataset_df[outcome_column] + # Implements an outer join of proc_cols dataset = backend.df_engine.df_like(dataset_df, proc_cols) @@ -1356,6 +1417,45 @@ def build_dataset( # Embed features with fixed encoders dataset = embed_fixed_features(dataset, feature_configs, metadata, backend) + # If training a reward model, perform grouping and joining on dataset + if mode == "training" and global_preprocessing_parameters["reward_dataset"] is not None: + + def parse_id_rows_group(rows_group): + rows_idxs = rows_group.index + + # Retrieve the outcome of the rows in this group + if len(rows_idxs) != 2: + raise ValueError( + f"Incorrect number of text rows for session ID {rows_group.name} when processing the " + f"reward model training dataset: expect 2 rows per session ID, got {len(rows_idxs)} rows." + ) + outcome_first = dataset.loc[rows_idxs[0]][outcome_column] + outcome_second = dataset.loc[rows_idxs[1]][outcome_column] + if not any( + [ + outcome_first == chosen_value and outcome_second == rejected_value, + outcome_first == rejected_value and outcome_second == chosen_value, + ] + ): + raise ValueError( + f"Incorrect labeling of the 2 text rows for session ID {rows_group.name} when processing " + f"the reward model training dataset: expect one row to be labeled as {chosen_value}, " + f"and one row to be labeled as {rejected_value}, but got {outcome_first} and {outcome_second}." + ) + + # Return the text transcripts for row in specific order + if outcome_first == chosen_value: + return [rows_group.loc[rows_idxs[0]], rows_group.loc[rows_idxs[1]]] + else: + return [rows_group.loc[rows_idxs[1]], rows_group.loc[rows_idxs[0]]] + + # Group dataset rows by ID, aggregate group data + dataset_id_groups = dataset.groupby(id_column) + dataset_refactored = dataset_id_groups[transcript_column].apply(parse_id_rows_group).reset_index() + if "split" in dataset.columns: + dataset_refactored["split"] = dataset_id_groups["split"].apply(lambda x: list(x)[0]).reset_index()["split"] + dataset = dataset_refactored + return dataset, metadata diff --git a/ludwig/models/ecd.py b/ludwig/models/ecd.py index d8d921598d2..fb12411d625 100644 --- a/ludwig/models/ecd.py +++ b/ludwig/models/ecd.py @@ -6,7 +6,7 @@ import torch from ludwig.combiners.combiners import create_combiner -from ludwig.constants import MODEL_ECD +from ludwig.constants import MODEL_ECD, MODEL_RWD from ludwig.globals import MODEL_WEIGHTS_FILE_NAME from ludwig.models.base import BaseModel from ludwig.schema.model_types.ecd import ECDModelConfig @@ -177,3 +177,16 @@ def get_augmentation_pipelines(self) -> AugmentationPipelines: ).get_augmentation_pipeline() return AugmentationPipelines(augmentation_pipelines) + + +class RWD(ECD): + """This class represents a Reward Model, a model type that takes as input some feature (i.e. text) and predicts + a single scalar output representing the reward/preference of that input. + + This model type is used for applications such as RLHF fine-tuning of LLMs. This model class is a subclass of ECD, + and uses most of ECD's code and pathways. 
+ """ + + @staticmethod + def type() -> str: + return MODEL_RWD diff --git a/ludwig/models/registry.py b/ludwig/models/registry.py index 5cb724cf229..30cecf5b884 100644 --- a/ludwig/models/registry.py +++ b/ludwig/models/registry.py @@ -1,7 +1,7 @@ import logging -from ludwig.constants import MODEL_ECD, MODEL_GBM, MODEL_LLM -from ludwig.models.ecd import ECD +from ludwig.constants import MODEL_ECD, MODEL_GBM, MODEL_LLM, MODEL_RWD +from ludwig.models.ecd import ECD, RWD from ludwig.models.llm import LLM logger = logging.getLogger(__name__) @@ -24,4 +24,5 @@ def gbm(*args, **kwargs): MODEL_ECD: ECD, MODEL_GBM: gbm, MODEL_LLM: LLM, + MODEL_RWD: RWD, } diff --git a/ludwig/modules/loss_modules.py b/ludwig/modules/loss_modules.py index 62be167ef57..7ec5657b4a0 100644 --- a/ludwig/modules/loss_modules.py +++ b/ludwig/modules/loss_modules.py @@ -35,6 +35,7 @@ MAPELossConfig, MSELossConfig, NextTokenSoftmaxCrossEntropyLossConfig, + RewardLossConfig, RMSELossConfig, RMSPELossConfig, SequenceSoftmaxCrossEntropyLossConfig, @@ -264,3 +265,14 @@ def __init__(self, config: CORNLossConfig): def forward(self, preds: Tensor, target: Tensor) -> Tensor: num_classes = preds.shape[1] return corn_loss(preds, target, num_classes=num_classes) + + +@register_loss(RewardLossConfig) +class RewardLoss(nn.Module, LogitsInputsMixin): + """Reward loss.""" + + def __init__(self, config: RewardLossConfig): + super().__init__() + + def forward(self, chosen_reward: Tensor, rejected_reward: Tensor) -> Tensor: + return -1 * nn.functional.logsigmoid(chosen_reward - rejected_reward).mean() diff --git a/ludwig/schema/decoders/base.py b/ludwig/schema/decoders/base.py index f1e27833fd8..d142cb0a908 100644 --- a/ludwig/schema/decoders/base.py +++ b/ludwig/schema/decoders/base.py @@ -2,7 +2,18 @@ from typing import Dict, List, Tuple, Union from ludwig.api_annotations import DeveloperAPI -from ludwig.constants import BINARY, CATEGORY, MODEL_ECD, MODEL_GBM, MODEL_LLM, NUMBER, SET, TIMESERIES, VECTOR +from ludwig.constants import ( + BINARY, + CATEGORY, + MODEL_ECD, + MODEL_GBM, + MODEL_LLM, + MODEL_RWD, + NUMBER, + SET, + TIMESERIES, + VECTOR, +) from ludwig.schema import common_fields from ludwig.schema import utils as schema_utils from ludwig.schema.decoders.utils import register_decoder_config @@ -108,7 +119,7 @@ def module_name(cls): @DeveloperAPI -@register_decoder_config("regressor", [BINARY, NUMBER], model_types=[MODEL_ECD, MODEL_GBM]) +@register_decoder_config("regressor", [BINARY, NUMBER], model_types=[MODEL_ECD, MODEL_GBM, MODEL_RWD]) @ludwig_dataclass class RegressorConfig(BaseDecoderConfig): """RegressorConfig is a dataclass that configures the parameters used for a regressor decoder.""" diff --git a/ludwig/schema/encoders/text_encoders.py b/ludwig/schema/encoders/text_encoders.py index acfbad5469f..51d09933983 100644 --- a/ludwig/schema/encoders/text_encoders.py +++ b/ludwig/schema/encoders/text_encoders.py @@ -1,7 +1,7 @@ from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Union from ludwig.api_annotations import DeveloperAPI -from ludwig.constants import MODEL_ECD, MODEL_GBM, TEXT +from ludwig.constants import MODEL_ECD, MODEL_GBM, MODEL_RWD, TEXT from ludwig.error import ConfigValidationError from ludwig.schema import utils as schema_utils from ludwig.schema.encoders.sequence_encoders import SequenceEncoderConfig @@ -3144,7 +3144,7 @@ def module_name(): @DeveloperAPI -@register_encoder_config("tf_idf", TEXT, model_types=[MODEL_ECD, MODEL_GBM]) +@register_encoder_config("tf_idf", TEXT, 
model_types=[MODEL_ECD, MODEL_GBM, MODEL_RWD]) @ludwig_dataclass class TfIdfEncoderConfig(SequenceEncoderConfig): type: str = schema_utils.ProtectedString("tf_idf") diff --git a/ludwig/schema/features/loss/loss.py b/ludwig/schema/features/loss/loss.py index f4ec1472e9d..48c905c44fc 100644 --- a/ludwig/schema/features/loss/loss.py +++ b/ludwig/schema/features/loss/loss.py @@ -12,6 +12,7 @@ MEAN_SQUARED_ERROR, NEXT_TOKEN_SOFTMAX_CROSS_ENTROPY, NUMBER, + REWARD, ROOT_MEAN_SQUARED_ERROR, ROOT_MEAN_SQUARED_PERCENTAGE_ERROR, SEQUENCE, @@ -475,3 +476,30 @@ def class_weights(self) -> int: @property def class_similarities_temperature(self) -> int: return 0 + + +@DeveloperAPI +@register_loss([NUMBER]) +@ludwig_dataclass +class RewardLossConfig(BaseLossConfig): + type: str = schema_utils.ProtectedString( + REWARD, + description=( + "This loss function is used to train reward models in Ludwig, for the purposes of RLHF. The " + "reward model will typically be an LLM or other large Transformer, with a single numerical output " + "feature. To train these models, data is provided in terms of pairs of responses (texts), one " + "of which is chosen and one of which is rejected, representing a human ranking assessment. The " + "model is trained using a contrastive loss procedure, maximizing the difference between the reward " + "score of the chosen text and the score of the rejected text." + ), + ) + + weight: float = schema_utils.NonNegativeFloat( + default=1.0, + description="Weight of the loss.", + parameter_metadata=LOSS_METADATA["RewardLoss"]["weight"], + ) + + @classmethod + def name(self) -> str: + return "Reward Model Loss" diff --git a/ludwig/schema/features/number_feature.py b/ludwig/schema/features/number_feature.py index 97ea49123c6..20dfa84e8bd 100644 --- a/ludwig/schema/features/number_feature.py +++ b/ludwig/schema/features/number_feature.py @@ -1,7 +1,7 @@ from typing import List, Tuple, Union from ludwig.api_annotations import DeveloperAPI -from ludwig.constants import MEAN_SQUARED_ERROR, MODEL_ECD, MODEL_GBM, NUMBER +from ludwig.constants import MEAN_SQUARED_ERROR, MODEL_ECD, MODEL_GBM, MODEL_RWD, NUMBER from ludwig.schema import utils as schema_utils from ludwig.schema.decoders.base import BaseDecoderConfig from ludwig.schema.decoders.utils import DecoderDataclassField @@ -21,6 +21,7 @@ gbm_output_config_registry, input_mixin_registry, output_mixin_registry, + rwd_output_config_registry, ) from ludwig.schema.metadata import FEATURE_METADATA from ludwig.schema.metadata.parameter_metadata import INTERNAL_ONLY @@ -153,6 +154,17 @@ class ECDNumberOutputFeatureConfig(NumberOutputFeatureConfig): ) +@DeveloperAPI +@rwd_output_config_registry.register(NUMBER) +@ludwig_dataclass +class RWDNumberOutputFeatureConfig(NumberOutputFeatureConfig): + decoder: BaseDecoderConfig = DecoderDataclassField( + MODEL_RWD, + feature_type=NUMBER, + default="regressor", + ) + + @DeveloperAPI @gbm_output_config_registry.register(NUMBER) @ludwig_dataclass diff --git a/ludwig/schema/features/text_feature.py b/ludwig/schema/features/text_feature.py index be87778bbee..a7f23030a07 100644 --- a/ludwig/schema/features/text_feature.py +++ b/ludwig/schema/features/text_feature.py @@ -4,6 +4,7 @@ MODEL_ECD, MODEL_GBM, MODEL_LLM, + MODEL_RWD, NEXT_TOKEN_SOFTMAX_CROSS_ENTROPY, PERPLEXITY, SEQUENCE_SOFTMAX_CROSS_ENTROPY, @@ -30,6 +31,7 @@ llm_input_config_registry, llm_output_config_registry, output_mixin_registry, + rwd_input_config_registry, ) from ludwig.schema.metadata import FEATURE_METADATA from 
ludwig.schema.metadata.parameter_metadata import INTERNAL_ONLY @@ -91,6 +93,17 @@ class LLMTextInputFeatureConfig(TextInputFeatureConfig): ) +@DeveloperAPI +@rwd_input_config_registry.register(TEXT) +@ludwig_dataclass +class RWDTextInputFeatureConfig(TextInputFeatureConfig): + encoder: BaseEncoderConfig = EncoderDataclassField( + MODEL_RWD, + feature_type=TEXT, + default="tf_idf", + ) + + @DeveloperAPI @output_mixin_registry.register(TEXT) @ludwig_dataclass diff --git a/ludwig/schema/features/utils.py b/ludwig/schema/features/utils.py index 34abd2eee15..8c4914f691e 100644 --- a/ludwig/schema/features/utils.py +++ b/ludwig/schema/features/utils.py @@ -1,7 +1,7 @@ from collections import defaultdict from ludwig.api_annotations import DeveloperAPI -from ludwig.constants import MODEL_ECD, MODEL_GBM, MODEL_LLM +from ludwig.constants import MODEL_ECD, MODEL_GBM, MODEL_LLM, MODEL_RWD from ludwig.schema import utils as schema_utils from ludwig.utils.registry import Registry @@ -11,10 +11,12 @@ ecd_input_config_registry = input_config_registries[MODEL_ECD] gbm_input_config_registry = input_config_registries[MODEL_GBM] llm_input_config_registry = input_config_registries[MODEL_LLM] +rwd_input_config_registry = input_config_registries[MODEL_RWD] ecd_output_config_registry = output_config_registries[MODEL_ECD] gbm_output_config_registry = output_config_registries[MODEL_GBM] llm_output_config_registry = output_config_registries[MODEL_LLM] +rwd_output_config_registry = output_config_registries[MODEL_RWD] input_mixin_registry = Registry() output_mixin_registry = Registry() diff --git a/ludwig/schema/metadata/configs/loss.yaml b/ludwig/schema/metadata/configs/loss.yaml index 240491bac00..50d83a6bb07 100644 --- a/ludwig/schema/metadata/configs/loss.yaml +++ b/ludwig/schema/metadata/configs/loss.yaml @@ -52,3 +52,6 @@ SigmoidCrossEntropyLoss: expected_impact: 3 weight: expected_impact: 2 +RewardLoss: + weight: + expected_impact: 2 diff --git a/ludwig/schema/metadata/configs/preprocessing.yaml b/ludwig/schema/metadata/configs/preprocessing.yaml index 11218989437..e657ae7c132 100644 --- a/ludwig/schema/metadata/configs/preprocessing.yaml +++ b/ludwig/schema/metadata/configs/preprocessing.yaml @@ -146,3 +146,9 @@ cache_encoder_embeddings: it's not always the case that you would always want to enable it when possible. 
expected_impact: 1 ui_display_name: Cache Encoder Embeddings +reward_dataset: + id_column: The name of the reward model training dataset session ID (and reward placeholder) column + outcome_column: The name of the reward model training dataset human-labeled chosen/rejected outcome column + chosen_value: The value of the string in the outcome column corresponding to chosen samples + rejected_value: The value of the string in the outcome column corresponding to rejected samples + transcript_column: The name of the reward model training dataset input text transcript to train with column diff --git a/ludwig/schema/model_types/ecd.py b/ludwig/schema/model_types/ecd.py index 967d12ae143..b39da66a616 100644 --- a/ludwig/schema/model_types/ecd.py +++ b/ludwig/schema/model_types/ecd.py @@ -36,3 +36,12 @@ class ECDModelConfig(ModelConfig): preprocessing: PreprocessingConfig = PreprocessingField().get_default_field() defaults: ECDDefaultsConfig = ECDDefaultsField().get_default_field() hyperopt: Optional[HyperoptConfig] = HyperoptField().get_default_field() + + +@DeveloperAPI +@register_model_type(name="rwd") +@ludwig_dataclass +class RWDModelConfig(ECDModelConfig): + """Parameters for RWD (Reward Model).""" + + model_type: str = schema_utils.ProtectedString("rwd") diff --git a/ludwig/schema/preprocessing.py b/ludwig/schema/preprocessing.py index 30127893240..d50c58c2cf4 100644 --- a/ludwig/schema/preprocessing.py +++ b/ludwig/schema/preprocessing.py @@ -1,3 +1,5 @@ +from typing import Any, Dict + from ludwig.api_annotations import DeveloperAPI from ludwig.constants import RANDOM from ludwig.schema import utils as schema_utils @@ -38,6 +40,13 @@ class PreprocessingConfig(schema_utils.BaseMarshmallowConfig): default=RANDOM, ) + reward_dataset: Dict[str, Any] = schema_utils.Dict( + default=None, + allow_none=True, + description="If not None, the input dataset will be preprocessed to train an RLHF reward model.", + parameter_metadata=PREPROCESSING_METADATA["reward_dataset"], + ) + @DeveloperAPI class PreprocessingField(schema_utils.DictMarshmallowField): diff --git a/ludwig/schema/trainer.py b/ludwig/schema/trainer.py index 60cdb4caace..21111bb0d50 100644 --- a/ludwig/schema/trainer.py +++ b/ludwig/schema/trainer.py @@ -12,6 +12,7 @@ MODEL_ECD, MODEL_GBM, MODEL_LLM, + MODEL_RWD, TRAINING, ) from ludwig.error import ConfigValidationError @@ -811,6 +812,18 @@ class FineTuneTrainerConfig(ECDTrainerConfig): ) +@DeveloperAPI +@register_trainer_schema(MODEL_RWD) +@ludwig_dataclass +class RWDTrainerConfig(ECDTrainerConfig): + """Dataclass that configures most of the hyperparameters used for Reward Model training.""" + + base_learning_rate: float = schema_utils.NonNegativeFloat( + default=0.0, + description="Base learning rate used for training in the Reward Model trainer.", + ) + + @DeveloperAPI def get_model_type_jsonschema(model_type: str = MODEL_ECD): enum = [MODEL_ECD] diff --git a/ludwig/trainers/__init__.py b/ludwig/trainers/__init__.py index e25dba8b547..37a9e656391 100644 --- a/ludwig/trainers/__init__.py +++ b/ludwig/trainers/__init__.py @@ -12,3 +12,9 @@ import ludwig.trainers.trainer_llm # noqa: F401 except ImportError: pass + + +try: + import ludwig.trainers.trainer_rlhf # noqa: F401 +except ImportError: + pass diff --git a/ludwig/trainers/trainer_rlhf.py b/ludwig/trainers/trainer_rlhf.py new file mode 100644 index 00000000000..011b89b5f66 --- /dev/null +++ b/ludwig/trainers/trainer_rlhf.py @@ -0,0 +1,148 @@ +import logging +from typing import Dict, List, Optional, Tuple + +import torch + +from 
ludwig.constants import MODEL_RWD +from ludwig.distributed.base import DistributedStrategy +from ludwig.models.ecd import RWD +from ludwig.modules.loss_modules import RewardLoss +from ludwig.schema.trainer import RWDTrainerConfig +from ludwig.trainers.registry import register_trainer +from ludwig.trainers.trainer import Trainer +from ludwig.utils.batch_size_tuner import BatchSizeEvaluator +from ludwig.utils.defaults import default_random_seed + +logger = logging.getLogger(__name__) + + +@register_trainer(MODEL_RWD) +class RWDTrainer(Trainer): + """This class trains models of type Reward Model.""" + + @staticmethod + def get_schema_cls(): + return RWDTrainerConfig + + def __init__( + self, + config: RWDTrainerConfig, + model: RWD, + resume: float = False, + skip_save_model: bool = False, + skip_save_progress: bool = False, + skip_save_log: bool = False, + callbacks: List = None, + report_tqdm_to_ray=False, + random_seed: float = default_random_seed, + distributed: Optional[DistributedStrategy] = None, + device: Optional[str] = None, + **kwargs, + ): + super().__init__( + config, + model, + resume, + skip_save_model, + skip_save_progress, + skip_save_log, + callbacks, + report_tqdm_to_ray, + random_seed, + distributed, + device, + **kwargs, + ) + + # Save the reward model loss function + self.reward_loss_function = RewardLoss({}) + + def _create_batch_size_evaluator(self) -> BatchSizeEvaluator: + trainer = self + + class _TrainerBatchSizeEvaluator(BatchSizeEvaluator): + def reset(self): + trainer.model.reset_metrics() + trainer.optimizer.zero_grad() + + def step(self, batch_size: int): + trainer.distributed.set_batch_size(trainer.dist_model, batch_size) + inputs = { + input_feature_name: [ + input_feature.create_sample_input(batch_size=batch_size).to(trainer.device), + input_feature.create_sample_input(batch_size=batch_size).to(trainer.device), + ] + for input_feature_name, input_feature in trainer.model.input_features.items() + } + targets = { + output_feature_name: output_feature.create_sample_output(batch_size=batch_size).to(trainer.device) + for output_feature_name, output_feature in trainer.model.output_features.items() + } + trainer.train_step(inputs, targets) + + return _TrainerBatchSizeEvaluator() + + def train_step( + self, inputs: Dict[str, torch.Tensor], targets: Dict[str, torch.Tensor], should_step: bool = True + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + """Performs a single training step of the RLHF reward model. + + Params: + inputs: A dictionary of input data, from feature name to tensor. + targets: A dictionary of target data, from feature name to tensor. + should_step: Whether to perform a step of the optimizer after computing gradients. + + Returns: + A tuple of the loss tensor and a dictionary of loss for every output feature. + """ + if self.use_amp is True: + raise ValueError("Invalid trainer arguments for RLHF reward model") + + # Validate inputs and targets + if not len(inputs) == 1: + raise ValueError(f"Invalid reward model training data inputs, expect 1 input feature, got {len(inputs)}.") + if not len(targets) == 1: + raise ValueError( + f"Invalid reward model training data targets, expect 1 target feature, got {len(targets)}." 
+ ) + id_column = list(targets.keys())[0] + transcript_column = list(inputs.keys())[0] + + # Run forward-propagation of the chosen and rejected inputs + with self.distributed.prepare_model_update(self.dist_model, should_step=should_step): + # Obtain model predictions and loss + inputs_chosen = {transcript_column: inputs[transcript_column][0]} + inputs_rejected = {transcript_column: inputs[transcript_column][1]} + model_output_chosen = self.dist_model(inputs_chosen) + model_output_rejected = self.dist_model(inputs_rejected) + logits_chosen = model_output_chosen[f"{id_column}::logits"] + logits_rejected = model_output_rejected[f"{id_column}::logits"] + loss = self.reward_loss_function(logits_chosen, logits_rejected) + loss = loss / self.gradient_accumulation_steps + all_losses = {"reward_loss": loss} + + # Begin the backward pass + variables = self.dist_model.parameters() + self.distributed.backward(loss, self.dist_model) + + if not should_step: + # Short-circuit the parameter updates if we are still accumulating gradients + return loss, all_losses + + # Wait for gradient aggregation to complete before clipping the gradients + # When using AMP, we need to do this before unscaling. + # See: https://github.com/horovod/horovod/blob/master/examples/pytorch/pytorch_mnist.py + self.distributed.wait_optimizer_synced(self.optimizer) + + if self.distributed.allow_clip_gradients(): + # Clip gradients + self.clip_grads(variables) + + # Apply gradient updates + with self.distributed.prepare_optimizer_update(self.optimizer): + # Because we already synchronized above, we skip doing so here + self.distributed.step(self.optimizer) + + self.distributed.zero_grad(self.optimizer) + + return loss, all_losses diff --git a/tests/integration_tests/test_preprocessing.py b/tests/integration_tests/test_preprocessing.py index 1bf0610443e..fadbd159174 100644 --- a/tests/integration_tests/test_preprocessing.py +++ b/tests/integration_tests/test_preprocessing.py @@ -163,6 +163,60 @@ def test_strip_whitespace_category(csv_filename, tmpdir): assert len(np.unique(train_ds.dataset[cat_feat[PROC_COLUMN]])) == cat_feat[DECODER]["vocab_size"] +def test_rlhf_reward_model_data_preprocessor(): + id_column = "reward_session_id" + outcome_column = "outcome" + chosen_value = "some_value_1" + rejected_value = "some_value_2" + transcript_column = "transcript" + + # Define the features + input_features = [ + text_feature( + name=transcript_column, + encoder={"type": "auto_transformer", "pretrained_model_name_or_path": "bert-base-uncased"}, + ) + ] + output_features = [number_feature(name=id_column)] + backend = LocalTestBackend() + config = {"input_features": input_features, "output_features": output_features} + + # Generate random dataframe + dataframe = generate_data_as_dataframe(input_features, output_features, num_examples=20) + + # Add reward model training pairs + dataframe[id_column] = dataframe.index // 2 + dataframe[outcome_column] = np.where(dataframe.index % 2, rejected_value, chosen_value) + + # Modify config with preprocessing + config["preprocessing"] = { + "reward_dataset": { + "id_column": id_column, + "outcome_column": outcome_column, + "chosen_value": chosen_value, + "rejected_value": rejected_value, + "transcript_column": transcript_column, + } + } + config["model_type"] = "rwd" + + # Run preprocessing, get output dataset + ludwig_model = LudwigModel(config, backend=backend) + train_dataset, _, _, metadata = ludwig_model.preprocess(dataset=dataframe) + + # Validate the processed dataset columns + dataset = 
train_dataset.dataset + id_column = config["output_features"][0]["proc_column"] + transcript_column = config["input_features"][0]["proc_column"] + dataset_columns_expected = sorted([id_column, transcript_column]) + dataset_columns_actual = sorted(dataset.keys()) + assert dataset_columns_actual == dataset_columns_expected + + # Validate each row in the processed dataset + for row_id in range(len(dataset[id_column])): + assert len(dataset[transcript_column][row_id]) == 2 + + @pytest.mark.parametrize( "backend", [ diff --git a/tests/integration_tests/test_trainer.py b/tests/integration_tests/test_trainer.py index b7f623875e8..b47f2e41be8 100644 --- a/tests/integration_tests/test_trainer.py +++ b/tests/integration_tests/test_trainer.py @@ -17,6 +17,7 @@ binary_feature, category_feature, generate_data, + generate_data_as_dataframe, LocalTestBackend, number_feature, RAY_BACKEND_CONFIG, @@ -36,6 +37,7 @@ from ludwig.data.dataset.ray import RayDataset from ludwig.models.gbm import GBM + from ludwig.modules.loss_modules import RewardLoss from ludwig.schema.model_config import ModelConfig from ludwig.schema.trainer import GBMTrainerConfig from ludwig.trainers.trainer_lightgbm import LightGBMRayTrainer @@ -216,6 +218,62 @@ def test_changing_parameters_on_plateau(tmpdir): model.train(training_set=data_csv, validation_set=val_csv, test_set=test_csv, output_directory=tmpdir) +def test_rlhf_reward_model_trainer(tmpdir): + id_column = "reward_session_id" + outcome_column = "outcome" + chosen_value = "some_value_1" + rejected_value = "some_value_2" + transcript_column = "transcript" + + # Define the features + input_features = [ + text_feature( + name=transcript_column, + encoder={"type": "auto_transformer", "pretrained_model_name_or_path": "gpt2", "trainable": True}, + ) + ] + output_features = [number_feature(name=id_column)] + backend = LocalTestBackend() + config = {"input_features": input_features, "output_features": output_features} + + # Generate random dataframe + dataframe = generate_data_as_dataframe(input_features, output_features, num_examples=20) + + # Add reward model training pairs + dataframe[id_column] = dataframe.index // 2 + dataframe[outcome_column] = np.where(dataframe.index % 2, rejected_value, chosen_value) + + # Modify config with preprocessing + config["preprocessing"] = { + "reward_dataset": { + "id_column": id_column, + "outcome_column": outcome_column, + "chosen_value": chosen_value, + "rejected_value": rejected_value, + "transcript_column": transcript_column, + } + } + config[TRAINER] = { + "epochs": 2, + BATCH_SIZE: 4, + "learning_rate": 1.0, + } + config["model_type"] = "rwd" + + # Train Ludwig model with the dataset + ludwig_model = LudwigModel(config, backend=backend) + ludwig_model.train(training_set=dataframe, output_directory=tmpdir) + + +def test_rlhf_reward_model_loss(): + reward_loss_function = RewardLoss({}) + + # Test the reward loss function + assert reward_loss_function(torch.tensor(100.0), torch.tensor(50.0)) < torch.tensor(1e-15) + assert reward_loss_function(torch.tensor(50.0), torch.tensor(100.0)) > torch.tensor(10) + assert reward_loss_function(torch.tensor(100.0), torch.tensor(100.0)) > torch.tensor(0.4) + + @pytest.mark.distributed def test_lightgbm_dataset_partition(ray_cluster_2cpu): # Create a LightGBM model with a Ray backend
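For reference, a minimal end-to-end sketch of how this new pathway is exercised, distilled from test_rlhf_reward_model_trainer and test_rlhf_reward_model_data_preprocessor above; the column names, encoder choice, and trainer settings are illustrative assumptions drawn from those tests, not defaults mandated by the patch. The trainer minimizes the pairwise loss -log sigmoid(r_chosen - r_rejected), which equals ln 2 (about 0.693) when the two rewards are equal, consistent with the assertions in test_rlhf_reward_model_loss.

import pandas as pd

from ludwig.api import LudwigModel

# Pairwise preference data: exactly two transcripts per session ID, one labeled
# chosen and one labeled rejected, mirroring the layout built in the tests.
n_sessions = 10
df = pd.DataFrame(
    {
        "reward_session_id": [i // 2 for i in range(2 * n_sessions)],
        "transcript": [f"candidate response {i}" for i in range(2 * n_sessions)],
        "outcome": ["chosen" if i % 2 == 0 else "rejected" for i in range(2 * n_sessions)],
    }
)

config = {
    "model_type": "rwd",
    "input_features": [{"name": "transcript", "type": "text", "encoder": {"type": "tf_idf"}}],
    "output_features": [{"name": "reward_session_id", "type": "number"}],
    "preprocessing": {
        "reward_dataset": {
            "id_column": "reward_session_id",
            "outcome_column": "outcome",
            "chosen_value": "chosen",
            "rejected_value": "rejected",
            "transcript_column": "transcript",
        }
    },
    "trainer": {"epochs": 2, "batch_size": 4},
}

model = LudwigModel(config)
# Preprocessing groups the two rows per session into (chosen, rejected) pairs,
# and training is dispatched to RWDTrainer via the MODEL_RWD registry entries.
model.train(training_set=df, output_directory="results")

The tf_idf encoder used here is the default registered for rwd text inputs in this patch; the trainer test instead uses a trainable auto_transformer (gpt2), which is the more realistic RLHF reward-model setup.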