Full rewrite #14
Merged
This is an official full rewrite of the library, getting ready for the 2.0 update. This rewrite includes many bug fixes along with new features. Your old code might only need to change a little bit; however, previous training runs cannot run with this new update 😢 (unless you make some updates to your checkpoint folder, which might not be practical).
REPLACEMENTS AND REMOVALS:
- `model_saving` parameter in the Trainer class no longer exists. It was replaced by the `trainer.register_model_saving` function (see the migration sketch after this list).
- `model_saving_below` and `model_saving_above` parameters in Trainer were removed. They now live in `register_model_saving`.
- `optim` parameter no longer exists in the HyperParameters object (or dictionary). It was replaced by `optimizer`.
- `collate_fn` parameter in Trainer no longer exists. The available options are `collate_fn_train` and `collate_fn_val`.
- `checkpoint` parameter in Trainer no longer exists. It was replaced by a boolean parameter `enable_checkpointing` (defaults to `True`).
- `validation_step` has changed from `def validation_step(self, batch)` to `def validation_step(self, key, batch)`.
- `status_dict` no longer exists. It was replaced by `state`, a class containing the previous and new parameters that track training state.
- `report_train_loss_per_epoch` was removed; its functionality is now handled by `log_every`. If `log_every` is set to a value less than 0 (e.g. -1), the train loss is reported at the end of the epoch.
- `handlers` parameter in Trainer was removed, since it was producing a lot of errors and crashes.
- `shuffle_validation` parameter in Trainer was removed, since it does not make sense to shuffle a validation dataset.
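To make these renames concrete, here is a minimal migration sketch. Only the parameter and function names come from this PR; the exact signatures, the import path, and the `MyModule` class are assumptions and may differ from the real API.

```python
# Hypothetical migration sketch -- names taken from this PR, signatures assumed.
# from <library> import Trainer, Module  # import path depends on the library

def my_collate(batch):  # user-defined collate function (placeholder)
    return batch

# Before the rewrite (no longer valid):
# trainer = Trainer(model_saving="best_valid_loss", model_saving_below=1.5,
#                   collate_fn=my_collate, checkpoint="checkpoints/")

# After the rewrite (assumed shape):
trainer = Trainer(
    collate_fn_train=my_collate,   # replaces the single collate_fn
    collate_fn_val=my_collate,
    enable_checkpointing=True,     # boolean replacement for checkpoint
)
trainer.register_model_saving(
    "best_valid_loss",             # illustrative saving criterion
    model_saving_below=1.5,        # moved here from the Trainer constructor
)

# HyperParameters: the "optim" key was renamed to "optimizer".
hp = {"optimizer": "AdamW", "lr": 1e-3}

# validation_step now also receives the evaluation dataset key:
class MyModule(Module):
    def validation_step(self, key, batch):
        ...
```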
NEW FEATURES:
- Evaluation dataset keys are now passed to the `validation_step` function (see the usage sketch after this list).
- `compile_kwargs`: a dictionary with additional kwargs for `torch.compile`.
- The `loop` function in Trainer can now be modified by inheritance! If you want to add more customization to your training loop, you can create another trainer class that inherits from Trainer and overrides the loop.
- `callback` parameter in Trainer can now contain multiple callbacks!
- `disable_model_saving` parameter in Trainer.
- `safe_mode` parameter in Trainer. Running in safe mode (the default) means that forward passes go through the corresponding wrapper (DDP, FSDP or DeepSpeedEngine). With `safe_mode=False`, the wrapper is skipped and the model is used directly. This slightly improves throughput, although gradients may not be correctly synchronized across all devices.
- `metrics` parameter in Trainer can now be a dictionary, where keys are dataset key names and values are the metrics to compute for that particular evaluation dataset. Basically, you can now have different metrics per dataset.
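A hedged sketch of how the new options might be combined; the parameter names (`compile_kwargs`, `callback`, `metrics`, `safe_mode`) come from this PR, while the callback and metric classes are placeholders and the real constructor may differ.

```python
# Hypothetical usage sketch -- parameter names from this PR, everything else assumed.
trainer = Trainer(
    compile_kwargs={"mode": "max-autotune"},        # extra kwargs forwarded to torch.compile
    callback=[LoggingCallback(), EarlyStopping()],  # multiple callbacks now allowed
    metrics={                                       # different metrics per evaluation dataset,
        "val_clean": [Accuracy()],                  # keyed by dataset key
        "val_noisy": [Accuracy(), F1Score()],
    },
    safe_mode=True,  # default: forward passes run through the DDP/FSDP/DeepSpeedEngine wrapper
)

# The training loop can now be customized by inheritance:
class MyTrainer(Trainer):
    def loop(self, *args, **kwargs):
        # add custom behavior around the default loop here
        return super().loop(*args, **kwargs)
```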
BUG FIXES:
- `clip_grad` parameter in Trainer was not working with DeepSpeed, because the configuration file sets an automatic default value of 1.0. This behavior was changed so that gradient clipping is always specified through the `clip_grad` parameter. The default gradient clipping value, independent of the strategy applied, is now always 0.
- `grad_accumulation_steps` was not being handled correctly and led to incorrect results.
- `patience` was not being handled correctly. If this value was set higher than 0, training would always end up stopping, even when results kept improving, because patience was reduced every time the model was saved (after evaluation).
- `eval_when_start=True` no longer saves the model or checkpoint, because there is no point in saving the model when no progress has been made yet.
- `eval_when_finish=True` no longer triggers a second evaluation when one was already run on the last step of an epoch or at the end of an epoch.
- The `set_seed` function now saves a global seed that can be accessed later when a new epoch starts, so per-epoch seeds are set to GLOBAL_SEED + EPOCH (see the sketch after this list).
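To illustrate the new seeding behavior, here is a minimal self-contained sketch of the GLOBAL_SEED + EPOCH rule described above (the function name and the way the global seed is stored are illustrative, not the library's actual implementation):

```python
import random

import numpy as np
import torch

GLOBAL_SEED = 42  # in the library this is saved once by set_seed

def reseed_for_epoch(epoch: int) -> None:
    """Derive a per-epoch seed from the saved global seed."""
    seed = GLOBAL_SEED + epoch
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

for epoch in range(3):
    reseed_for_epoch(epoch)  # seeds 42, 43, 44
```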