Skip to content

Commit

Permalink
Style edits
Browse files Browse the repository at this point in the history
  • Loading branch information
asdataminer committed Jun 15, 2023
1 parent a93a1b8 commit e311741
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 10 deletions.
9 changes: 5 additions & 4 deletions ludwig/models/ecd.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,11 @@ def get_augmentation_pipelines(self) -> AugmentationPipelines:


class RWD(ECD):
"""
This class represents a Reward Model, a model that inputs some feature (i.e. text transcript) and predicts
a single scalar output representing the reward/preference of that input. This model type is used for applications
such as RLHF fine-tuning of LLMs. This model class is a subclass of ECD, and uses most of ECD's code and pathways.
"""This class represents a Reward Model, a model type that takes as input some feature (i.e. text) and predicts
a single scalar output representing the reward/preference of that input.
This model type is used for applications such as RLHF fine-tuning of LLMs. This model class is a subclass of ECD,
and uses most of ECD's code and pathways.
"""
@staticmethod
def type() -> str:
Expand Down
11 changes: 10 additions & 1 deletion ludwig/schema/decoders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,16 @@

from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import (
BINARY, CATEGORY, MODEL_ECD, MODEL_GBM, MODEL_LLM, MODEL_RWD, NUMBER, SET, TIMESERIES, VECTOR)
BINARY,
CATEGORY,
MODEL_ECD,
MODEL_GBM,
MODEL_LLM,
MODEL_RWD,
NUMBER,
SET,
TIMESERIES,
VECTOR)
from ludwig.schema import common_fields
from ludwig.schema import utils as schema_utils
from ludwig.schema.decoders.utils import register_decoder_config
Expand Down
2 changes: 1 addition & 1 deletion ludwig/schema/features/number_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
ecd_defaults_config_registry,
ecd_input_config_registry,
ecd_output_config_registry,
rwd_output_config_registry,
gbm_defaults_config_registry,
gbm_input_config_registry,
gbm_output_config_registry,
input_mixin_registry,
output_mixin_registry,
rwd_output_config_registry,
)
from ludwig.schema.metadata import FEATURE_METADATA
from ludwig.schema.metadata.parameter_metadata import INTERNAL_ONLY
Expand Down
8 changes: 4 additions & 4 deletions ludwig/trainers/trainer_rlhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@


@register_trainer(MODEL_RWD)
class RewardModelTrainer(Trainer):
class RWDTrainer(Trainer):
"""This class trains models of type Reward Model."""

@staticmethod
def get_schema_cls():
return RWDTrainerConfig
Expand Down Expand Up @@ -77,9 +79,7 @@ def train_step(

# Validate inputs and targets
if not len(inputs) == 1:
raise ValueError(
f"Invalid reward model training data inputs, expect 1 input feature, got {len(inputs)}."
)
raise ValueError(f"Invalid reward model training data inputs, expect 1 input feature, got {len(inputs)}.")
if not len(targets) == 1:
raise ValueError(
f"Invalid reward model training data targets, expect 1 target feature, got {len(targets)}."
Expand Down

0 comments on commit e311741

Please sign in to comment.