Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jun 14, 2023
1 parent dd675d6 commit b5d61b9
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 21 deletions.
33 changes: 18 additions & 15 deletions ludwig/data/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1344,14 +1344,16 @@ def build_dataset(
transcript_column = global_preprocessing_parameters["reward_dataset"]["transcript_column"]

# Validate the input configuration
if not all([
len(config["input_features"]) == 1,
len(config["output_features"]) == 1,
config["input_features"][0]["name"] == transcript_column,
config["input_features"][0]["type"] == "text",
config["output_features"][0]["name"] == id_column,
config["output_features"][0]["type"] == "number",
]):
if not all(
[
len(config["input_features"]) == 1,
len(config["output_features"]) == 1,
config["input_features"][0]["name"] == transcript_column,
config["input_features"][0]["type"] == "text",
config["output_features"][0]["name"] == id_column,
config["output_features"][0]["type"] == "number",
]
):
raise ValueError(f"Invalid reward model training configuration, received {config}.")

# Validate the input dataframe's columns
Expand Down Expand Up @@ -1417,6 +1419,7 @@ def build_dataset(

# If training a reward model, perform grouping and joining on dataset
if mode == "training" and "reward_dataset" in global_preprocessing_parameters:

def parse_id_rows_group(rows_group):
rows_idxs = rows_group.index

Expand All @@ -1428,10 +1431,12 @@ def parse_id_rows_group(rows_group):
)
outcome_first = dataset.loc[rows_idxs[0]][outcome_column]
outcome_second = dataset.loc[rows_idxs[1]][outcome_column]
if not any([
outcome_first == chosen_value and outcome_second == rejected_value,
outcome_first == rejected_value and outcome_second == chosen_value,
]):
if not any(
[
outcome_first == chosen_value and outcome_second == rejected_value,
outcome_first == rejected_value and outcome_second == chosen_value,
]
):
raise ValueError(
f"Incorrect labeling of the 2 text rows for session ID {rows_group.name} when processing "
f"the reward model training dataset: expect one row to be labeled as {chosen_value}, "
Expand All @@ -1446,9 +1451,7 @@ def parse_id_rows_group(rows_group):

# Group dataset rows by ID, aggregate group data
dataset_id_groups = dataset.groupby(id_column)
dataset_refactored = (
dataset_id_groups[transcript_column].apply(parse_id_rows_group).reset_index()
)
dataset_refactored = dataset_id_groups[transcript_column].apply(parse_id_rows_group).reset_index()
if "split" in dataset.columns:
dataset_refactored["split"] = dataset_id_groups["split"].apply(lambda x: list(x)[0]).reset_index()["split"]
dataset = dataset_refactored
Expand Down
8 changes: 5 additions & 3 deletions tests/integration_tests/test_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,11 @@ def test_rlhf_reward_model_data_preprocessor():
transcript_column = "Transcript"

# Define the features
input_features = [text_feature(
name=transcript_column,
encoder={"type": "auto_transformer", "pretrained_model_name_or_path": "gpt2"})]
input_features = [
text_feature(
name=transcript_column, encoder={"type": "auto_transformer", "pretrained_model_name_or_path": "gpt2"}
)
]
output_features = [number_feature(name=id_column)]
backend = LocalTestBackend()
config = {"input_features": input_features, "output_features": output_features}
Expand Down
8 changes: 5 additions & 3 deletions tests/integration_tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,11 @@ def test_rlhf_reward_model_trainer(tmpdir):
transcript_column = "Transcript"

# Define the features
input_features = [text_feature(
name=transcript_column,
encoder={"type": "auto_transformer", "pretrained_model_name_or_path": "gpt2"})]
input_features = [
text_feature(
name=transcript_column, encoder={"type": "auto_transformer", "pretrained_model_name_or_path": "gpt2"}
)
]
output_features = [number_feature(name=id_column)]
backend = LocalTestBackend()
config = {"input_features": input_features, "output_features": output_features}
Expand Down

0 comments on commit b5d61b9

Please sign in to comment.