Skip to content

Commit

Permalink
Further reduce memory consumption [release]
Browse files Browse the repository at this point in the history
  • Loading branch information
AjayP13 committed Apr 30, 2024
1 parent df473ac commit ca14270
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/trainers/train_hf_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ..datasets import OutputDatasetColumn, OutputIterableDatasetColumn
from ..llms.llm import _check_temperature_and_top_p
from ..utils.arg_utils import AUTO, Default, default_to
from ..utils.distributed_utils import set_current_accelerator
from ..utils.fs_utils import mkdir
from ..utils.hf_model_utils import is_peft_model
from ..utils.hf_training_utils import (
Expand Down Expand Up @@ -116,6 +117,7 @@ def compute_metrics(eval_pred):
self.accelerator.prepare_optimizer = (
lambda optimizer, *args, **kwargs: optimizer
)
set_current_accelerator(self.accelerator)

def get_train_dataloader(self) -> DataLoader:
# PPOTrainer's .step() method does not allow smaller than batch size inputs
Expand Down

0 comments on commit ca14270

Please sign in to comment.