Add RLHF Reward Trainer and Loss #3435

Open · wants to merge 30 commits into base: master

Commits (30)
2bc1d01  Make preprocessing modifications V1 (asdataminer, Jun 5, 2023)
d6cc331  Add dataset validation (asdataminer, Jun 5, 2023)
ea07940  Small edit (asdataminer, Jun 5, 2023)
e993d7b  Add tests (asdataminer, Jun 7, 2023)
98db29b  Small edits (asdataminer, Jun 8, 2023)
518feca  Another small edit (asdataminer, Jun 8, 2023)
8ac5101  Modify processing strategy (asdataminer, Jun 12, 2023)
b61a174  Small edit (asdataminer, Jun 12, 2023)
e78152d  Another small edit (asdataminer, Jun 12, 2023)
440cbec  Small edit (asdataminer, Jun 12, 2023)
b12f8ac  Add loss items (asdataminer, Jun 8, 2023)
936194d  Add trainer (asdataminer, Jun 8, 2023)
2b74e7b  Modify reward model trainer (asdataminer, Jun 13, 2023)
becf832  Small edits (asdataminer, Jun 13, 2023)
7d4243f  Add trainer, data edits (asdataminer, Jun 13, 2023)
316e2bf  Add schema changes (asdataminer, Jun 14, 2023)
dd675d6  Add refactored processing logic and trainer (asdataminer, Jun 14, 2023)
cae42ad  Style edits (asdataminer, Jun 14, 2023)
a0808cf  Modify tests (asdataminer, Jun 14, 2023)
9b46959  More test edits (asdataminer, Jun 15, 2023)
cc40f7c  Make reward model a separate model type (asdataminer, Jun 15, 2023)
a93a1b8  Additional refactor edits (asdataminer, Jun 15, 2023)
e311741  Style edits (asdataminer, Jun 15, 2023)
1662827  Add text encoder (asdataminer, Jun 15, 2023)
4860265  Small edits (asdataminer, Jun 15, 2023)
9cf18d6  Modify trainer, tests passing (asdataminer, Jun 15, 2023)
82bbffe  Bug fix (asdataminer, Jun 20, 2023)
3675f42  Small edit (asdataminer, Jun 20, 2023)
4f852ea  Another edit (asdataminer, Jun 20, 2023)
b6ef5d1  Reward loss test (asdataminer, Jun 20, 2023)
2 changes: 2 additions & 0 deletions ludwig/constants.py
@@ -72,6 +72,7 @@
MEAN_ABSOLUTE_PERCENTAGE_ERROR = "mean_absolute_percentage_error"
HUBER = "huber"
CORN = "corn"
REWARD = "reward_model_loss"
R2 = "r2"
EDIT_DISTANCE = "edit_distance"
PERPLEXITY = "perplexity"
@@ -269,6 +270,7 @@
MODEL_ECD = "ecd"
MODEL_GBM = "gbm"
MODEL_LLM = "llm"
MODEL_RWD = "rwd"
DASK_MODULE_NAME = "dask.dataframe"
LUDWIG_VERSION = "ludwig_version"

100 changes: 100 additions & 0 deletions ludwig/data/preprocessing.py
@@ -1322,6 +1322,67 @@ def build_dataset(
        if reshape is not None:
            proc_cols[proc_column] = backend.df_engine.map_objects(proc_cols[proc_column], lambda x: x.reshape(-1))

    # If training a reward model, prepare the processed columns
    if mode == "training" and global_preprocessing_parameters["reward_dataset"] is not None:
        reward_parameter_names = [
            "id_column",
            "outcome_column",
            "chosen_value",
            "rejected_value",
            "transcript_column",
        ]
        if not all(
            param_name in global_preprocessing_parameters["reward_dataset"] for param_name in reward_parameter_names
        ):
            raise ValueError(f"Invalid reward training preprocessing parameters, expect {reward_parameter_names}.")

        # Obtain column names and other values
        id_column = global_preprocessing_parameters["reward_dataset"]["id_column"]
        outcome_column = global_preprocessing_parameters["reward_dataset"]["outcome_column"]
        chosen_value = global_preprocessing_parameters["reward_dataset"]["chosen_value"]
        rejected_value = global_preprocessing_parameters["reward_dataset"]["rejected_value"]
        transcript_column = global_preprocessing_parameters["reward_dataset"]["transcript_column"]

        # Validate the input configuration
        if not all(
            [
                len(config["input_features"]) == 1,
                len(config["output_features"]) == 1,
                config["input_features"][0]["name"] == transcript_column,
                config["input_features"][0]["type"] == "text",
                config["output_features"][0]["name"] == id_column,
                config["output_features"][0]["type"] == "number",
            ]
        ):
            raise ValueError(f"Invalid reward model training configuration, received {config}.")

        # Validate the input dataframe's columns
        dataset_columns_expected = sorted([id_column, outcome_column, transcript_column])
        dataset_columns_actual = sorted(dataset_df.columns)
        if "split" in dataset_columns_actual:
            dataset_columns_actual.remove("split")
        if dataset_columns_actual != dataset_columns_expected:
            raise ValueError(
                f"Invalid reward training input dataset, expect columns {dataset_columns_expected}, "
                f"got columns {dataset_columns_actual}."
            )

        # Validate the processed dataset columns
        id_column = config["output_features"][0]["proc_column"]
        transcript_column = config["input_features"][0]["proc_column"]
        proc_columns_expected = sorted([id_column, transcript_column])
        proc_columns_actual = sorted(proc_cols.keys())
        if "split" in proc_columns_actual:
            proc_columns_actual.remove("split")
        if proc_columns_actual != proc_columns_expected:
            raise ValueError(
                f"Invalid reward training processed dataset, expect columns {proc_columns_expected}, "
                f"got columns {proc_columns_actual}."
            )

        # Add the outcome column to processed columns
        proc_cols[outcome_column] = dataset_df[outcome_column]

    # Implements an outer join of proc_cols
    dataset = backend.df_engine.df_like(dataset_df, proc_cols)

@@ -1356,6 +1417,45 @@ def build_dataset(
    # Embed features with fixed encoders
    dataset = embed_fixed_features(dataset, feature_configs, metadata, backend)

    # If training a reward model, perform grouping and joining on dataset
    if mode == "training" and global_preprocessing_parameters["reward_dataset"] is not None:

        def parse_id_rows_group(rows_group):
            rows_idxs = rows_group.index

            # Retrieve the outcome of the rows in this group
            if len(rows_idxs) != 2:
                raise ValueError(
                    f"Incorrect number of text rows for session ID {rows_group.name} when processing the "
                    f"reward model training dataset: expect 2 rows per session ID, got {len(rows_idxs)} rows."
                )
            outcome_first = dataset.loc[rows_idxs[0]][outcome_column]
            outcome_second = dataset.loc[rows_idxs[1]][outcome_column]
            if not any(
                [
                    outcome_first == chosen_value and outcome_second == rejected_value,
                    outcome_first == rejected_value and outcome_second == chosen_value,
                ]
            ):
                raise ValueError(
                    f"Incorrect labeling of the 2 text rows for session ID {rows_group.name} when processing "
                    f"the reward model training dataset: expect one row to be labeled as {chosen_value}, "
                    f"and one row to be labeled as {rejected_value}, but got {outcome_first} and {outcome_second}."
                )

            # Return the text transcripts for the two rows in a fixed order: [chosen, rejected]
            if outcome_first == chosen_value:
                return [rows_group.loc[rows_idxs[0]], rows_group.loc[rows_idxs[1]]]
            else:
                return [rows_group.loc[rows_idxs[1]], rows_group.loc[rows_idxs[0]]]

        # Group dataset rows by ID, aggregate group data
        dataset_id_groups = dataset.groupby(id_column)
        dataset_refactored = dataset_id_groups[transcript_column].apply(parse_id_rows_group).reset_index()
        if "split" in dataset.columns:
            dataset_refactored["split"] = dataset_id_groups["split"].apply(lambda x: list(x)[0]).reset_index()["split"]
        dataset = dataset_refactored

    return dataset, metadata


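For illustration, the grouping step above collapses the two rows that share a session ID into a single row whose transcript column holds a [chosen, rejected] pair. A minimal standalone sketch of that behavior (column names and values here are hypothetical, not taken from this PR):

import pandas as pd

# Hypothetical raw reward dataset: two rows per session ID, one chosen and one rejected.
df = pd.DataFrame(
    {
        "session_id": [1, 1, 2, 2],
        "transcript": ["good reply A", "bad reply A", "bad reply B", "good reply B"],
        "outcome": ["chosen", "rejected", "rejected", "chosen"],
    }
)


def order_pair(group):
    # Mirror parse_id_rows_group: return the transcripts as [chosen, rejected].
    chosen = group.loc[group["outcome"] == "chosen", "transcript"].iloc[0]
    rejected = group.loc[group["outcome"] == "rejected", "transcript"].iloc[0]
    return [chosen, rejected]


pairs = df.groupby("session_id").apply(order_pair)
print(pairs.tolist())  # [['good reply A', 'bad reply A'], ['good reply B', 'bad reply B']]
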
15 changes: 14 additions & 1 deletion ludwig/models/ecd.py
@@ -6,7 +6,7 @@
import torch

from ludwig.combiners.combiners import create_combiner
from ludwig.constants import MODEL_ECD
from ludwig.constants import MODEL_ECD, MODEL_RWD
from ludwig.globals import MODEL_WEIGHTS_FILE_NAME
from ludwig.models.base import BaseModel
from ludwig.schema.model_types.ecd import ECDModelConfig
@@ -177,3 +177,16 @@ def get_augmentation_pipelines(self) -> AugmentationPipelines:
).get_augmentation_pipeline()

return AugmentationPipelines(augmentation_pipelines)


class RWD(ECD):
    """This class represents a Reward Model, a model type that takes as input some feature (e.g. text) and predicts
    a single scalar output representing the reward/preference of that input.

    This model type is used for applications such as RLHF fine-tuning of LLMs. This model class is a subclass of ECD,
    and uses most of ECD's code and pathways.
    """

    @staticmethod
    def type() -> str:
        return MODEL_RWD
5 changes: 3 additions & 2 deletions ludwig/models/registry.py
@@ -1,7 +1,7 @@
import logging

from ludwig.constants import MODEL_ECD, MODEL_GBM, MODEL_LLM
from ludwig.models.ecd import ECD
from ludwig.constants import MODEL_ECD, MODEL_GBM, MODEL_LLM, MODEL_RWD
from ludwig.models.ecd import ECD, RWD
from ludwig.models.llm import LLM

logger = logging.getLogger(__name__)
@@ -24,4 +24,5 @@ def gbm(*args, **kwargs):
    MODEL_ECD: ECD,
    MODEL_GBM: gbm,
    MODEL_LLM: LLM,
    MODEL_RWD: RWD,
}
12 changes: 12 additions & 0 deletions ludwig/modules/loss_modules.py
@@ -35,6 +35,7 @@
    MAPELossConfig,
    MSELossConfig,
    NextTokenSoftmaxCrossEntropyLossConfig,
    RewardLossConfig,
    RMSELossConfig,
    RMSPELossConfig,
    SequenceSoftmaxCrossEntropyLossConfig,
@@ -264,3 +265,14 @@ def __init__(self, config: CORNLossConfig):
    def forward(self, preds: Tensor, target: Tensor) -> Tensor:
        num_classes = preds.shape[1]
        return corn_loss(preds, target, num_classes=num_classes)


@register_loss(RewardLossConfig)
class RewardLoss(nn.Module, LogitsInputsMixin):
    """Reward loss."""

    def __init__(self, config: RewardLossConfig):
        super().__init__()

    def forward(self, chosen_reward: Tensor, rejected_reward: Tensor) -> Tensor:
        return -1 * nn.functional.logsigmoid(chosen_reward - rejected_reward).mean()
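The forward pass above is a pairwise preference loss: it is small when the chosen reward exceeds the rejected reward and grows as the ordering flips. A self-contained sketch of the same computation (the tensor values are made up):

import torch
import torch.nn as nn

chosen_reward = torch.tensor([1.2, 0.3, 2.0])
rejected_reward = torch.tensor([0.7, 0.9, -1.0])

# Same expression as RewardLoss.forward above.
loss = -1 * nn.functional.logsigmoid(chosen_reward - rejected_reward).mean()
print(loss)  # the second pair, where the rejected reward wins, contributes most of the loss
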
15 changes: 13 additions & 2 deletions ludwig/schema/decoders/base.py
@@ -2,7 +2,18 @@
from typing import Dict, List, Tuple, Union

from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import BINARY, CATEGORY, MODEL_ECD, MODEL_GBM, MODEL_LLM, NUMBER, SET, TIMESERIES, VECTOR
from ludwig.constants import (
    BINARY,
    CATEGORY,
    MODEL_ECD,
    MODEL_GBM,
    MODEL_LLM,
    MODEL_RWD,
    NUMBER,
    SET,
    TIMESERIES,
    VECTOR,
)
from ludwig.schema import common_fields
from ludwig.schema import utils as schema_utils
from ludwig.schema.decoders.utils import register_decoder_config
@@ -108,7 +119,7 @@ def module_name(cls):


@DeveloperAPI
@register_decoder_config("regressor", [BINARY, NUMBER], model_types=[MODEL_ECD, MODEL_GBM])
@register_decoder_config("regressor", [BINARY, NUMBER], model_types=[MODEL_ECD, MODEL_GBM, MODEL_RWD])
@ludwig_dataclass
class RegressorConfig(BaseDecoderConfig):
"""RegressorConfig is a dataclass that configures the parameters used for a regressor decoder."""
4 changes: 2 additions & 2 deletions ludwig/schema/encoders/text_encoders.py
@@ -1,7 +1,7 @@
from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Union

from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import MODEL_ECD, MODEL_GBM, TEXT
from ludwig.constants import MODEL_ECD, MODEL_GBM, MODEL_RWD, TEXT
from ludwig.error import ConfigValidationError
from ludwig.schema import utils as schema_utils
from ludwig.schema.encoders.sequence_encoders import SequenceEncoderConfig
@@ -3144,7 +3144,7 @@ def module_name():


@DeveloperAPI
@register_encoder_config("tf_idf", TEXT, model_types=[MODEL_ECD, MODEL_GBM])
@register_encoder_config("tf_idf", TEXT, model_types=[MODEL_ECD, MODEL_GBM, MODEL_RWD])
@ludwig_dataclass
class TfIdfEncoderConfig(SequenceEncoderConfig):
    type: str = schema_utils.ProtectedString("tf_idf")
28 changes: 28 additions & 0 deletions ludwig/schema/features/loss/loss.py
@@ -12,6 +12,7 @@
    MEAN_SQUARED_ERROR,
    NEXT_TOKEN_SOFTMAX_CROSS_ENTROPY,
    NUMBER,
    REWARD,
    ROOT_MEAN_SQUARED_ERROR,
    ROOT_MEAN_SQUARED_PERCENTAGE_ERROR,
    SEQUENCE,
@@ -475,3 +476,30 @@ def class_weights(self) -> int:
    @property
    def class_similarities_temperature(self) -> int:
        return 0


@DeveloperAPI
@register_loss([NUMBER])
@ludwig_dataclass
class RewardLossConfig(BaseLossConfig):
    type: str = schema_utils.ProtectedString(
        REWARD,
        description=(
            "This loss function is used to train reward models in Ludwig, for the purposes of RLHF. The "
            "reward model will typically be an LLM or other large Transformer, with a single numerical output "
            "feature. To train these models, data is provided in terms of pairs of responses (texts), one "
            "of which is chosen and one of which is rejected, representing a human ranking assessment. The "
            "model is trained using a contrastive loss procedure, maximizing the difference between the reward "
            "score of the chosen text and the score of the rejected text."
        ),
    )

    weight: float = schema_utils.NonNegativeFloat(
        default=1.0,
        description="Weight of the loss.",
        parameter_metadata=LOSS_METADATA["RewardLoss"]["weight"],
    )

    @classmethod
    def name(cls) -> str:
        return "Reward Model Loss"
14 changes: 13 additions & 1 deletion ludwig/schema/features/number_feature.py
@@ -1,7 +1,7 @@
from typing import List, Tuple, Union

from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import MEAN_SQUARED_ERROR, MODEL_ECD, MODEL_GBM, NUMBER
from ludwig.constants import MEAN_SQUARED_ERROR, MODEL_ECD, MODEL_GBM, MODEL_RWD, NUMBER
from ludwig.schema import utils as schema_utils
from ludwig.schema.decoders.base import BaseDecoderConfig
from ludwig.schema.decoders.utils import DecoderDataclassField
@@ -21,6 +21,7 @@
    gbm_output_config_registry,
    input_mixin_registry,
    output_mixin_registry,
    rwd_output_config_registry,
)
from ludwig.schema.metadata import FEATURE_METADATA
from ludwig.schema.metadata.parameter_metadata import INTERNAL_ONLY
@@ -153,6 +154,17 @@ class ECDNumberOutputFeatureConfig(NumberOutputFeatureConfig):
    )


@DeveloperAPI
@rwd_output_config_registry.register(NUMBER)
@ludwig_dataclass
class RWDNumberOutputFeatureConfig(NumberOutputFeatureConfig):
    decoder: BaseDecoderConfig = DecoderDataclassField(
        MODEL_RWD,
        feature_type=NUMBER,
        default="regressor",
    )


@DeveloperAPI
@gbm_output_config_registry.register(NUMBER)
@ludwig_dataclass
13 changes: 13 additions & 0 deletions ludwig/schema/features/text_feature.py
@@ -4,6 +4,7 @@
    MODEL_ECD,
    MODEL_GBM,
    MODEL_LLM,
    MODEL_RWD,
    NEXT_TOKEN_SOFTMAX_CROSS_ENTROPY,
    PERPLEXITY,
    SEQUENCE_SOFTMAX_CROSS_ENTROPY,
@@ -30,6 +31,7 @@
    llm_input_config_registry,
    llm_output_config_registry,
    output_mixin_registry,
    rwd_input_config_registry,
)
from ludwig.schema.metadata import FEATURE_METADATA
from ludwig.schema.metadata.parameter_metadata import INTERNAL_ONLY
@@ -91,6 +93,17 @@ class LLMTextInputFeatureConfig(TextInputFeatureConfig):
    )


@DeveloperAPI
@rwd_input_config_registry.register(TEXT)
@ludwig_dataclass
class RWDTextInputFeatureConfig(TextInputFeatureConfig):
    encoder: BaseEncoderConfig = EncoderDataclassField(
        MODEL_RWD,
        feature_type=TEXT,
        default="tf_idf",
    )


@DeveloperAPI
@output_mixin_registry.register(TEXT)
@ludwig_dataclass
4 changes: 3 additions & 1 deletion ludwig/schema/features/utils.py
@@ -1,7 +1,7 @@
from collections import defaultdict

from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import MODEL_ECD, MODEL_GBM, MODEL_LLM
from ludwig.constants import MODEL_ECD, MODEL_GBM, MODEL_LLM, MODEL_RWD
from ludwig.schema import utils as schema_utils
from ludwig.utils.registry import Registry

@@ -11,10 +11,12 @@
ecd_input_config_registry = input_config_registries[MODEL_ECD]
gbm_input_config_registry = input_config_registries[MODEL_GBM]
llm_input_config_registry = input_config_registries[MODEL_LLM]
rwd_input_config_registry = input_config_registries[MODEL_RWD]

ecd_output_config_registry = output_config_registries[MODEL_ECD]
gbm_output_config_registry = output_config_registries[MODEL_GBM]
llm_output_config_registry = output_config_registries[MODEL_LLM]
rwd_output_config_registry = output_config_registries[MODEL_RWD]

input_mixin_registry = Registry()
output_mixin_registry = Registry()
3 changes: 3 additions & 0 deletions ludwig/schema/metadata/configs/loss.yaml
@@ -52,3 +52,6 @@ SigmoidCrossEntropyLoss:
    expected_impact: 3
  weight:
    expected_impact: 2
RewardLoss:
  weight:
    expected_impact: 2
6 changes: 6 additions & 0 deletions ludwig/schema/metadata/configs/preprocessing.yaml
@@ -146,3 +146,9 @@ cache_encoder_embeddings:
    it's not always the case that you would always want to enable it when possible.
  expected_impact: 1
  ui_display_name: Cache Encoder Embeddings
reward_dataset:
  id_column: The name of the reward model training dataset's session ID (and reward placeholder) column
  outcome_column: The name of the reward model training dataset's human-labeled chosen/rejected outcome column
  chosen_value: The string value in the outcome column that marks a sample as chosen
  rejected_value: The string value in the outcome column that marks a sample as rejected
  transcript_column: The name of the reward model training dataset's column containing the input text transcripts to train on
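
Putting the pieces together, a reward-model training run would presumably go through Ludwig's usual Python API with a dataset containing the three columns documented above. This is a hedged end-to-end sketch only: the model_type value, the placement of reward_dataset under preprocessing, and every column, feature, and label name below are inferred from the validation code in this PR, and whether reward training is wired through LudwigModel.train exactly as shown is an assumption.

import pandas as pd

from ludwig.api import LudwigModel

# Hypothetical preference data: two rows per session_id, one chosen and one rejected.
df = pd.DataFrame(
    {
        "session_id": [1, 1, 2, 2],
        "transcript": ["reply A1", "reply A2", "reply B1", "reply B2"],
        "outcome": ["chosen", "rejected", "rejected", "chosen"],
    }
)

# Config sketch: one text input (the transcript) and one number output (the session ID /
# reward placeholder), matching the checks in build_dataset.
config = {
    "model_type": "rwd",
    "input_features": [{"name": "transcript", "type": "text", "encoder": {"type": "tf_idf"}}],
    "output_features": [{"name": "session_id", "type": "number"}],
    "preprocessing": {
        "reward_dataset": {
            "id_column": "session_id",
            "outcome_column": "outcome",
            "chosen_value": "chosen",
            "rejected_value": "rejected",
            "transcript_column": "transcript",
        }
    },
}

model = LudwigModel(config)
train_stats, preprocessed_data, output_directory = model.train(dataset=df)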