Modify trainer, tests passing
asdataminer committed Jun 15, 2023
1 parent 4860265 commit 9cf18d6
Showing 3 changed files with 41 additions and 5 deletions.
39 changes: 35 additions & 4 deletions ludwig/trainers/trainer_rlhf.py
@@ -10,6 +10,7 @@
 from ludwig.schema.trainer import RWDTrainerConfig
 from ludwig.trainers.registry import register_trainer
 from ludwig.trainers.trainer import Trainer
+from ludwig.utils.batch_size_tuner import BatchSizeEvaluator
 from ludwig.utils.defaults import default_random_seed
 
 logger = logging.getLogger(__name__)
@@ -56,6 +57,30 @@ def __init__(
         # Save the reward model loss function
         self.reward_loss_function = RewardLoss({})
 
+    def _create_batch_size_evaluator(self) -> BatchSizeEvaluator:
+        trainer = self
+
+        class _TrainerBatchSizeEvaluator(BatchSizeEvaluator):
+            def reset(self):
+                trainer.model.reset_metrics()
+                trainer.optimizer.zero_grad()
+
+            def step(self, batch_size: int):
+                trainer.distributed.set_batch_size(trainer.dist_model, batch_size)
+                inputs = {
+                    input_feature_name: [
+                        input_feature.create_sample_input(batch_size=batch_size).to(trainer.device),
+                        input_feature.create_sample_input(batch_size=batch_size).to(trainer.device),
+                    ] for input_feature_name, input_feature in trainer.model.input_features.items()
+                }
+                targets = {
+                    output_feature_name: output_feature.create_sample_output(batch_size=batch_size).to(trainer.device)
+                    for output_feature_name, output_feature in trainer.model.output_features.items()
+                }
+                trainer.train_step(inputs, targets)
+
+        return _TrainerBatchSizeEvaluator()
+
     def train_step(
         self, inputs: Dict[str, torch.Tensor], targets: Dict[str, torch.Tensor], should_step: bool = True
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
@@ -84,15 +109,21 @@ def train_step(
             raise ValueError(
                 f"Invalid reward model training data targets, expected 1 target feature, got {len(targets)}."
             )
+        id_column = list(targets.keys())[0]
+        transcript_column = list(inputs.keys())[0]
 
         # Run forward-propagation of the chosen and rejected inputs
         with self.distributed.prepare_model_update(self.dist_model, should_step=should_step):
             # Obtain model predictions and loss
-            model_output_chosen = self.dist_model(inputs[self.transcript_column][0])
-            model_output_rejected = self.dist_model(inputs[self.transcript_column][1])
-            loss = self.reward_loss_function(model_output_chosen, model_output_rejected)
+            inputs_chosen = {transcript_column: inputs[transcript_column][0]}
+            inputs_rejected = {transcript_column: inputs[transcript_column][1]}
+            model_output_chosen = self.dist_model(inputs_chosen)
+            model_output_rejected = self.dist_model(inputs_rejected)
+            logits_chosen = model_output_chosen[f"{id_column}::logits"]
+            logits_rejected = model_output_rejected[f"{id_column}::logits"]
+            loss = self.reward_loss_function(logits_chosen, logits_rejected)
             loss = loss / self.gradient_accumulation_steps
-            all_losses = loss
+            all_losses = {"reward_loss": loss}
 
             # Begin the backward pass
             variables = self.dist_model.parameters()
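Note on the logits handed to self.reward_loss_function above: RewardLoss itself is not part of this diff, but pairwise reward models are commonly trained with a -log sigmoid(r_chosen - r_rejected) objective. A minimal sketch of that formulation, not necessarily Ludwig's exact implementation:

import torch
import torch.nn.functional as F

def pairwise_reward_loss(logits_chosen: torch.Tensor, logits_rejected: torch.Tensor) -> torch.Tensor:
    # Encourage the chosen transcript to score higher than the rejected one.
    return -F.logsigmoid(logits_chosen - logits_rejected).mean()

# Example: loss = pairwise_reward_loss(torch.randn(4), torch.randn(4))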
2 changes: 1 addition & 1 deletion tests/integration_tests/test_preprocessing.py
@@ -198,7 +198,7 @@ def test_rlhf_reward_model_data_preprocessor():
"transcript_column": transcript_column,
}
}
config[TRAINER] = {"type": "reward_model"}
config["model_type"] = "rwd"

# Run preprocessing, get output dataset
ludwig_model = LudwigModel(config, backend=backend)
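The one-line change above moves reward-model selection off the trainer section and onto the model type. Assuming the TRAINER constant resolves to the "trainer" config key, the two styles look roughly like this (surrounding config keys elided):

# Before this commit: reward training was selected via the trainer section.
old_style = {"trainer": {"type": "reward_model"}}

# After this commit: the test selects the dedicated reward model type instead.
new_style = {"model_type": "rwd"}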
5 changes: 5 additions & 0 deletions tests/integration_tests/test_trainer.py
@@ -252,6 +252,11 @@ def test_rlhf_reward_model_trainer(tmpdir):
"transcript_column": transcript_column,
}
}
config[TRAINER] = {
"epochs": 2,
BATCH_SIZE: 4,
"learning_rate": 1.0,
}
config["model_type"] = "rwd"

# Train Ludwig model with the dataset
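For orientation, a rough sketch of how the trainer settings added in this test might be exercised end to end. Only model_type and the trainer block come from the diff above; the feature definitions and dataset path are illustrative placeholders, not taken from the test:

from ludwig.api import LudwigModel

config = {
    "model_type": "rwd",
    "input_features": [{"name": "transcript", "type": "text"}],    # placeholder
    "output_features": [{"name": "label", "type": "number"}],      # placeholder
    "trainer": {"epochs": 2, "batch_size": 4, "learning_rate": 1.0},
}

model = LudwigModel(config)
# model.train(dataset="reward_pairs.csv")  # illustrative dataset path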