Skip to content

Commit

Permalink
Bug fix
Browse files — browse the repository at this point in the history
  • Loading branch information
asdataminer committed Jun 20, 2023
1 parent 9cf18d6 commit 82bbffe
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 10 deletions.
4 changes: 2 additions & 2 deletions ludwig/data/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1323,7 +1323,7 @@ def build_dataset(
proc_cols[proc_column] = backend.df_engine.map_objects(proc_cols[proc_column], lambda x: x.reshape(-1))

# If training a reward model, prepare the processed columns
if mode == "training" and "reward_dataset" in global_preprocessing_parameters:
if mode == "training" and global_preprocessing_parameters["reward_dataset"] is not None:
reward_parameter_names = [
"id_column",
"outcome_column",
Expand Down Expand Up @@ -1418,7 +1418,7 @@ def build_dataset(
dataset = embed_fixed_features(dataset, feature_configs, metadata, backend)

# If training a reward model, perform grouping and joining on dataset
if mode == "training" and "reward_dataset" in global_preprocessing_parameters:
if mode == "training" and global_preprocessing_parameters["reward_dataset"] is not None:

def parse_id_rows_group(rows_group):
rows_idxs = rows_group.index
Expand Down
10 changes: 3 additions & 7 deletions ludwig/trainers/trainer_rlhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ def step(self, batch_size: int):
input_feature_name: [
input_feature.create_sample_input(batch_size=batch_size).to(trainer.device),
input_feature.create_sample_input(batch_size=batch_size).to(trainer.device),
] for input_feature_name, input_feature in trainer.model.input_features.items()
]
for input_feature_name, input_feature in trainer.model.input_features.items()
}
targets = {
output_feature_name: output_feature.create_sample_output(batch_size=batch_size).to(trainer.device)
Expand All @@ -94,12 +95,7 @@ def train_step(
Returns:
A tuple of the loss tensor and a dictionary of loss for every output feature.
"""
if not all(
[
self.use_amp is False,
self.evaluate_training_set is True,
]
):
if self.use_amp is True:
raise ValueError("Invalid trainer arguments for RLHF reward model")

# Validate inputs and targets
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def test_rlhf_reward_model_trainer(tmpdir):
input_features = [
text_feature(
name=transcript_column,
encoder={"type": "auto_transformer", "pretrained_model_name_or_path": "bert-base-uncased"},
encoder={"type": "auto_transformer", "pretrained_model_name_or_path": "gpt2", "trainable": True},
)
]
output_features = [number_feature(name=id_column)]
Expand Down

0 comments on commit 82bbffe

Please sign in to comment.