Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relax][Training] Trainer API #115

Merged
merged 15 commits into from
Feb 10, 2023

Conversation

SiriusNEO
Copy link
Contributor

@SiriusNEO SiriusNEO commented Jan 30, 2023

This PR brings a wrapper for relax training. The following things are done internally in this trainer:

  • Maintain (store/update) the parameters of the module.
  • Merge backbone and specified loss function together.
  • Build/Compile/Run the module.
  • Build/Compile/Run the optimizer. (using the same vm_config as we run the module.)

And it also provides two interfaces for loading params/exporting params.

Example:

trainer = Trainer(MLP, [1, 2], "main") # [1, 2] means input[1] and input[2] are parameters in this module.
trainer.set_loss(MSELoss(reduction="sum"), pred_sinfo, pred_sinfo)
trainer.set_vm_config(target="llvm")
trainer.set_optimizer(optim_type=SGD, lr=0.001).setup()
trainer.setup()
trainer.rand_init_params()
trainer.forward(*fwd_inputs)
trainer.backward(*bwd_inputs)

@SiriusNEO SiriusNEO marked this pull request as ready for review February 1, 2023 01:49
@SiriusNEO SiriusNEO changed the title [WIP][Relax][Training] Trainer API [Relax][Training] Trainer API Feb 1, 2023
@SiriusNEO SiriusNEO marked this pull request as draft February 1, 2023 15:42
@SiriusNEO SiriusNEO marked this pull request as ready for review February 4, 2023 11:10
python/tvm/relax/training/trainer.py Outdated Show resolved Hide resolved
python/tvm/relax/training/utils.py Outdated Show resolved Hide resolved
python/tvm/relax/training/setup_trainer.py Outdated Show resolved Hide resolved
python/tvm/relax/training/trainer.py Outdated Show resolved Hide resolved
python/tvm/relax/training/setup_trainer.py Outdated Show resolved Hide resolved
python/tvm/relax/training/trainer.py Outdated Show resolved Hide resolved
python/tvm/relax/training/setup_trainer.py Outdated Show resolved Hide resolved
MasterJH5574 pushed a commit that referenced this pull request Feb 8, 2023
* Add gpu ci.

* Update autotir gpu test.
spectrometerHBH pushed a commit to spectrometerHBH/relax that referenced this pull request Feb 9, 2023
* Add gpu ci.

* Update autotir gpu test.
Copy link
Member

@MasterJH5574 MasterJH5574 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only one minor point. Besides, would you like to update the PR description a bit so that it is consistent with the new API?

python/tvm/relax/training/setup_trainer.py Outdated Show resolved Hide resolved
@SiriusNEO SiriusNEO merged commit a65d808 into mlc-ai:relax Feb 10, 2023
MasterJH5574 pushed a commit that referenced this pull request Feb 12, 2023
This PR brings a wrapper for relax training. The following things are
done internally in this trainer:
- Maintain (store/update) the parameters of the module.
- Merge backbone and specified loss function together.
- Build/Compile/Run the module.
- Build/Compile/Run the optimizer. (using the same vm_config as we run
the module.)

And it also provides two interfaces for loading params/exporting params.

Example:
```
trainer = Trainer(MLP, [1, 2], "main") # [1, 2] means input[1] and input[2] are parameters in this module.
trainer.set_loss(MSELoss(reduction="sum"), pred_sinfo, pred_sinfo)
trainer.set_vm_config(target="llvm")
trainer.set_optimizer(optim_type=SGD, lr=0.001).setup()
trainer.setup()
trainer.rand_init_params()
trainer.forward(*fwd_inputs)
trainer.backward(*bwd_inputs)
```
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants