Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jun 15, 2023
1 parent a93a1b8 commit f6c84b1
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 9 deletions.
10 changes: 6 additions & 4 deletions ludwig/models/ecd.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,13 @@ def get_augmentation_pipelines(self) -> AugmentationPipelines:


class RWD(ECD):
"""This class represents a Reward Model, a model that inputs some feature (i.e. text transcript) and predicts a
single scalar output representing the reward/preference of that input.
This model type is used for applications such as RLHF fine-tuning of LLMs. This model class is a subclass of ECD,
and uses most of ECD's code and pathways.
"""
This class represents a Reward Model, a model that inputs some feature (i.e. text transcript) and predicts
a single scalar output representing the reward/preference of that input. This model type is used for applications
such as RLHF fine-tuning of LLMs. This model class is a subclass of ECD, and uses most of ECD's code and pathways.
"""

@staticmethod
def type() -> str:
return MODEL_RWD
12 changes: 11 additions & 1 deletion ludwig/schema/decoders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,17 @@

from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import (
BINARY, CATEGORY, MODEL_ECD, MODEL_GBM, MODEL_LLM, MODEL_RWD, NUMBER, SET, TIMESERIES, VECTOR)
BINARY,
CATEGORY,
MODEL_ECD,
MODEL_GBM,
MODEL_LLM,
MODEL_RWD,
NUMBER,
SET,
TIMESERIES,
VECTOR,
)
from ludwig.schema import common_fields
from ludwig.schema import utils as schema_utils
from ludwig.schema.decoders.utils import register_decoder_config
Expand Down
2 changes: 1 addition & 1 deletion ludwig/schema/features/number_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
ecd_defaults_config_registry,
ecd_input_config_registry,
ecd_output_config_registry,
rwd_output_config_registry,
gbm_defaults_config_registry,
gbm_input_config_registry,
gbm_output_config_registry,
input_mixin_registry,
output_mixin_registry,
rwd_output_config_registry,
)
from ludwig.schema.metadata import FEATURE_METADATA
from ludwig.schema.metadata.parameter_metadata import INTERNAL_ONLY
Expand Down
4 changes: 1 addition & 3 deletions ludwig/trainers/trainer_rlhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,7 @@ def train_step(

# Validate inputs and targets
if not len(inputs) == 1:
raise ValueError(
f"Invalid reward model training data inputs, expect 1 input feature, got {len(inputs)}."
)
raise ValueError(f"Invalid reward model training data inputs, expect 1 input feature, got {len(inputs)}.")
if not len(targets) == 1:
raise ValueError(
f"Invalid reward model training data targets, expect 1 target feature, got {len(targets)}."
Expand Down

0 comments on commit f6c84b1

Please sign in to comment.