@ghanvert ghanvert commented Apr 1, 2025

This is an official full rewrite of the library in preparation for the 2.0 update. It includes many bug fixes along with new features. Some of your old code might need minor changes, although previous training runs cannot resume with this new update 😢 (unless you manually update your checkpoint folder, which might not be practical).

REPLACEMENTS AND REMOVALS:

  • The model_saving parameter in the Trainer class no longer exists. It was replaced by the trainer.register_model_saving function.
  • The model_saving_below and model_saving_above parameters in Trainer were removed. They now live in register_model_saving.
  • The optim parameter no longer exists in the HyperParameters object (or dictionary). It was replaced by optimizer.
  • The collate_fn parameter in Trainer no longer exists. The available options are collate_fn_train and collate_fn_val.
  • The checkpoint parameter in Trainer no longer exists. It was replaced by a boolean parameter, enable_checkpointing (defaults to True).
  • The declaration of validation_step has changed from def validation_step(self, batch) to def validation_step(self, key, batch) (see the migration sketch after this list).
  • The internal status_dict no longer exists. It was replaced by state, a class containing both the previous and the new parameters that track training state.
  • report_train_loss_per_epoch was removed; its functionality is now handled by log_every. If log_every is set to a negative value (e.g. -1), the train loss is reported at the end of the epoch.
  • The handlers parameter in Trainer was removed since it was producing many errors and crashes.
  • The shuffle_validation parameter in Trainer was removed since it does not make sense to shuffle a validation dataset.
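
For reference, here is a minimal before/after migration sketch based on the renames above. The import line, the "best_valid_loss" criterion string, and the surrounding setup are assumptions for illustration, not the exact API:

```python
# Minimal migration sketch. The import path, the criterion string, and
# the class layout below are assumptions (the package name and full
# signatures are not shown in this PR).
# from <package> import Trainer, HyperParameters

# --- 1.x (old) ---
# hps = HyperParameters(optim="AdamW")
# trainer = Trainer(model_saving="best_valid_loss",
#                   collate_fn=my_collate,
#                   checkpoint=True)
#
# class MyModule:
#     def validation_step(self, batch):
#         ...

# --- 2.0 (new) ---
hps = HyperParameters(optimizer="AdamW")          # 'optim' -> 'optimizer'
trainer = Trainer(
    collate_fn_train=my_collate,                  # 'collate_fn' was split in two
    collate_fn_val=my_collate,
    enable_checkpointing=True,                    # replaces 'checkpoint'
)
trainer.register_model_saving("best_valid_loss")  # replaces 'model_saving'

class MyModule:
    def validation_step(self, key, batch):        # 'key' names the eval dataset
        ...
```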

NEW FEATURES:

  • Multiple evaluations are supported! You can now pass a list or dictionary of evaluation datasets. Each dataset has a corresponding key that can be accessed in the validation_step function.
  • Model saving now has additional syntax to save the best models based on metrics of specific datasets and on the best values of different metrics.
  • New compile_kwargs parameter in Trainer: a dictionary of additional kwargs for torch.compile.
  • A new, better-looking progress bar! It is actually the same library (tqdm), but with colors and a smaller footprint. This also removes some odd visual bugs where training and validation progress bars overlapped.
  • The loop function in Trainer can be modified through inheritance! If you want to customize your training loop further, you can create another trainer class that inherits from Trainer and overrides loop.
  • Better code, better throughput! Since this is an almost complete rewrite of the library, every decision in the code was made to optimize throughput and to make the code more readable than before, for better maintainability.
  • The callback parameter in Trainer can now contain multiple callbacks!
  • New disable_model_saving parameter in Trainer.
  • New safe_mode parameter in Trainer. Running in safe mode (the default) means that forward passes are done through the corresponding wrapper (DDP, FSDP or DeepSpeedEngine). With safe_mode=False, the wrapper is skipped and the model is used directly. This slightly improves throughput, although gradients are not guaranteed to be correctly synchronized across all devices.
  • With the new support for multiple evaluations, the metrics parameter in Trainer can be a dictionary whose keys are dataset keys and whose values are the metrics to compute for that particular evaluation dataset. In short, you can now have different metrics per dataset (see the sketch after this list).
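
The sketch below shows how these pieces fit together: a dictionary of evaluation datasets, per-dataset metrics, compile_kwargs, multiple callbacks, and a custom loop via inheritance. The dataset keys, metric names, callback classes, and the signature of loop are assumptions for illustration, not the documented API:

```python
# Illustrative sketch only; names and signatures below are assumptions.
val_datasets = {"clean": clean_val_ds, "noisy": noisy_val_ds}

trainer = Trainer(
    metrics={                                   # per-dataset metrics
        "clean": ["accuracy"],
        "noisy": ["accuracy", "f1"],
    },
    compile_kwargs={"mode": "max-autotune"},    # forwarded to torch.compile
    callback=[MyLoggingCallback(), MyTimerCallback()],  # hypothetical callbacks
)

class MyModule:
    def validation_step(self, key, batch):
        # 'key' is "clean" or "noisy", matching val_datasets above
        ...

class MyTrainer(Trainer):                       # customize the loop via inheritance
    def loop(self, *args, **kwargs):
        # add behavior before/after the default training loop
        return super().loop(*args, **kwargs)
```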

BUG FIXES:

  • The clip_grad parameter in Trainer was not working with DeepSpeed, because the configuration file sets an automatic value of 1.0 (its default). We changed this behavior so that gradient clipping is always specified through the clip_grad parameter. The default gradient clipping value, independent of the strategy applied, is now 0.
  • grad_accumulation_steps was not being handled correctly and led to incorrect results.
  • patience was not being handled correctly. If this value was set higher than 0, all model saving would eventually stop, even when results were better than previous ones, because patience was reduced every time the model was due to be saved (after evaluation).
  • The first evaluation with eval_when_start=True no longer saves the model or checkpoint, because there is no point in saving the model when no progress has been made at all.
  • The last evaluation with eval_when_finish=True no longer runs twice when another evaluation was already done on the last step of an epoch or at the end of an epoch.
  • Resuming from a checkpoint always set global seeds based on the epoch (0, 1, 2, etc.), regardless of whether the user had already set a seed. To mitigate this, setting a seed with the set_seed function now stores a global seed that can be accessed later in a new epoch, so seeds are set to GLOBAL_SEED + EPOCH (see the sketch after this list).
  • Resuming from a checkpoint while logging produced wrong results, in both the reported train loss and the current step number (at least in the first log after resuming).
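
A minimal sketch of the seeding fix described above, assuming set_seed stores the base seed in a module-level variable (the library's actual storage mechanism is not shown in this PR):

```python
import random

import numpy as np
import torch

_GLOBAL_SEED = None  # remembered by set_seed so later epochs can reuse it

def set_seed(seed: int) -> None:
    """Set all global seeds and remember the base seed."""
    global _GLOBAL_SEED
    _GLOBAL_SEED = seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

def reseed_for_epoch(epoch: int) -> None:
    """On resume (or each new epoch), derive the seed as GLOBAL_SEED + EPOCH."""
    base = _GLOBAL_SEED if _GLOBAL_SEED is not None else 0
    seed = base + epoch
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
```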

ghanvert added 24 commits March 23, 2025 04:39
…ompile by redirecting compiled models to module
…ce compilation in raw models in a distributed system causes many errors with TorchScript
…wargs for torch.compile customizations. Also added 'safe_mode' parameter (default is True) to ensure gradient synchronization and mixed precision settings handled by wrappers.
@ghanvert ghanvert merged commit 60263b6 into main Apr 2, 2025
1 check passed
@ghanvert ghanvert deleted the full-rewrite branch April 2, 2025 16:01