Add pytorch-lightning decorator to nano #3181
Conversation
python/chronos/test/bigdl/chronos/model/test_pytorch_lightning_wrapper.py (outdated; resolved)
```python
):
    r"""
    Create an instance from torch.utils.data.Dataset.
    Override pl.LightningDataModule.from_datasets for cpu usage, setting pin_memory as False by default.
```
That's strange, I wonder why.
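For context, a minimal, hypothetical sketch of a DataModule whose loaders default `pin_memory` to False for CPU-only training; this is the general shape, not the PR's actual override of `from_datasets`:

```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset


class CPUDataModule(pl.LightningDataModule):
    """Illustrative only: pin_memory only speeds up host-to-GPU copies,
    so a CPU-oriented DataModule can safely default it to False."""

    def __init__(self, train_dataset: Dataset, val_dataset: Dataset = None,
                 batch_size: int = 32, num_workers: int = 0):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.num_workers = num_workers

    def _loader(self, dataset: Dataset, shuffle: bool) -> DataLoader:
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle,
                          num_workers=self.num_workers, pin_memory=False)

    def train_dataloader(self) -> DataLoader:
        return self._loader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        return self._loader(self.val_dataset, shuffle=False)
```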
```python
from torch.utils.data import DataLoader, Dataset, IterableDataset


class LightningModuleWrapper(pl.LightningModule):
```
We have planned to put LightningModuleWrapper into nano.
```python
class LightningModuleWrapper(pl.LightningModule):
    def __init__(self, model_creator, configs: dict):
```
I think it would be better if we just accept three creators:
- a model creator
- a loss creator
- an optim creator

plus a config.
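A minimal sketch of the three-creator design this comment proposes; all names and signatures are assumptions for illustration, not the PR's code:

```python
import pytorch_lightning as pl


class LightningModuleWrapper(pl.LightningModule):
    """Hypothetical wrapper built from three creators plus a config dict."""

    def __init__(self, model_creator, loss_creator, optim_creator, config: dict):
        super().__init__()
        self.config = config
        self.model = model_creator(config)   # returns an nn.Module
        self.loss = loss_creator(config)     # e.g. returns nn.CrossEntropyLoss()
        self._optim_creator = optim_creator  # called lazily once the model exists

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return self.loss(self(x), y)

    def configure_optimizers(self):
        # The creator receives the built model, so it can bind model.parameters()
        return self._optim_creator(self.model, self.config)
```

The creator indirection sidesteps the construction-order problem discussed later in this thread: the optimizer is only instantiated inside `configure_optimizers`, after the model exists.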
python/nano/src/bigdl/nano/pytorch/vision/models/lightning_support.py (three outdated comments; resolved)
```python
return getattr(torch.optim, config.get("optim", "Adam"))(model.parameters(), lr=config.get("lr", 0.001))


@lightning_support.lightning_module(loss_creator, optimizer_creator, config)
```
This is too complex for the user; I think we can simply do something like:

```python
class MyModel(torch.nn.Module):
    ...

model = MyModel(...)
loss = nn.CrossEntropyLoss()
opt = optim.Adam(...)
trainer.fit(model, loss, opt, train_data, val_data)
```
If a user (let's say Chronos) has an nn.Module and wants to use both the training optimizations in our bigdl.nano trainer and onnxruntime inference (which is supported on a PL LightningModule), they can't use this API design to do both.
> If a user has an nn.Module and wants both the bigdl.nano trainer optimizations and onnxruntime inference, they can't use this API design to do both.
You may be able to do something like:

```python
loss = nn.CrossEntropyLoss()
opt = optim.Adam(...)

@pl_module(loss, opt)
class MyModel(torch.nn.Module):
    ...

model = MyModel(...)
trainer.fit(model, train_data, val_data)
```

But how do you plan to add onnxruntime support in this case?
I will do it like this:

```python
loss = nn.CrossEntropyLoss()
opt = optim.Adam(...)

@onnxruntime()
@pl_module(loss, opt)
class MyModel(torch.nn.Module):
    ...

model = MyModel(...)
trainer.fit(model, train_data, val_data)
```

as I stated in #3272.
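For illustration, one way such a class decorator could work; `pl_module` is a name from this thread, not a shipped API, and the body is a sketch:

```python
import pytorch_lightning as pl


def pl_module(loss, optimizer):
    """Hypothetical decorator: instantiating the decorated nn.Module
    subclass yields a pl.LightningModule that wraps it."""
    def decorator(cls):
        class Wrapped(pl.LightningModule):
            def __init__(self, *args, **kwargs):
                super().__init__()
                self.model = cls(*args, **kwargs)
                self.loss = loss
                self.optimizer = optimizer

            def forward(self, x):
                return self.model(x)

            def training_step(self, batch, batch_idx):
                x, y = batch
                return self.loss(self(x), y)

            def configure_optimizers(self):
                return self.optimizer

        return Wrapped
    return decorator
```

Note that this sketch takes an already-built optimizer, which is exactly the ordering problem raised below: `optim.Adam(...)` needs `model.parameters()`, and no model instance exists yet when the decorator is applied.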
Sounds good to me
Sure. I think we may make this PR focus only on the decorator support, just to make it easier to review and clearer; we will raise another PR for the trainer API change.

Which one do you think will have a better user experience?
```python
loss = nn.CrossEntropyLoss()
opt = optim.Adam(...)

@onnxruntime()
@pl_module(loss, opt)
class MyModel(torch.nn.Module):
    ...

model = MyModel(...)
```

or

```python
class MyModel(torch.nn.Module):
    ...

model = MyModel(...)
loss = nn.CrossEntropyLoss()
opt = optim.Adam(...)
nano_model = trainer.compile(model, loss, opt, onnx=True)
```
It seems to me that `trainer.compile` has a better user experience, since you cannot create an optimizer instance before the model instance is built, which forces a config dict in the decorator case.

Still, the method involves an unusual usage (i.e. `trainer.compile`). We need to give a thorough and detailed user guide and in-code warnings.

BTW, we don't need to implement the `onnx=True` parameter in this PR since I have not merged the onnx PR yet; I will do it later after this PR has been merged.
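To make the ordering constraint concrete, here is why the optimizer cannot exist before the model does:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
# The optimizer binds to the parameters at construction time, so a decorator
# applied at class definition time (before any instance exists) can only carry
# a config dict such as {"optim": "Adam", "lr": 0.01} instead of an optimizer.
opt = torch.optim.Adam(model.parameters(), lr=0.01)
```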
> It seems to me that `trainer.compile` has a better user experience … We need to give a thorough and detailed user guide and in-code warnings.
What do you mean by abnormal usage? I think the point is:

- If the user simply has a PyTorch model, he or she can directly use it in `nano.pytorch.trainer` methods (if the default behavior of `fit`, `test`, `predict`, etc., works for him or her).
- If the user needs more complex behavior (e.g., onnxruntime support), he or she needs to explicitly convert it to a pl_module, and we can provide an API based on either a decorator or a compile method.
> What do you mean by abnormal usage? …
I think we can just ask all users with an nn.Module to do this:

```python
model = Net()
model_pl = trainer.compile(model, loss, opt, onnx=bool)
# then use model_pl (a pl.LightningModule) to do anything else
# (e.g. `fit`, `test`, `predict` with the trainer)
```

Abnormal usage means that `trainer.compile` is not an easy name for users to come up with directly; we need examples, quickstarts, and user guides to guide them.
It does not have to be `trainer.compile`; maybe something similar to `ray.distributed`, such as `nano.pl_module`.
I don't think we should implement multiple interfaces for the same use case; it just confuses the users. Instead, just provide one single interface for one use case.
Sure, we need to decide which usage is more common and friendly to users. At the same time, onnxruntime support should be aligned if we choose to wrap a torch model instance instead of a torch class.

```python
@to_lightning(loss, torch.optim.Adam, lr=0.01)
class Net(nn.Module):
    pass

pl_model = Net()
```

or

```python
class Net(nn.Module):
    pass

pl_model = to_lightning(loss, torch.optim.Adam, lr=0.01)(Net())
```

I suppose we may have more extensions coming on the PyTorch Lightning model. Maybe we can wrap all these extensions into one so users don't have to decide which function to use:

```python
def composed(*decs):
    def deco(f):
        for dec in reversed(decs):
            f = dec(f)
        return f
    return deco


def preprocess(loss=None, optimizer=None, config=None, onnx=True):
    return composed(
        onnxruntime(onnx),
        to_lightning(loss, optimizer, **config)
    )


pl_model = preprocess(loss, torch.optim.Adam, {"lr": 0.01}, onnx=True)(model)
```

We can also integrate this function into …
Need a consistent API for all these use cases.
Having discussed with @TheaperDeng @yangw1234, we currently have 3 solutions. We prefer option 2: make the conversion a fixed process and give proper warnings and errors to inform users of the correct usage.

@jason-dai What are your thoughts?
I think we only need …
The main reason why we want to make:

```python
pl_model = trainer.compile(torch.vision.ResNet18(), loss, optimizer, ...)
pl_model = trainer.compile(LightningModule(...))  # This does nothing and returns LightningModule(...)
```

Both of the above are legal for users, so they don't have to change their Lightning code. This way, they can be warned to use …
It's OK to use …
As discussed above, made modifications: …
```python
@staticmethod
def compile(model: nn.Module, loss: _Loss = None, optimizer: torch.optim = None):
    """
    Compile a pytorch model into a pytorch-lightning model and return it.
```
the comment is incorrect if we also support LightningModule below
Please review the modified docstring. Any further suggestions on how to properly describe this function?
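Reading the thread and the commit list below together, the merged design is roughly the following; this is a sketch with assumed details (`LightningModuleFromTorch` is named in the commit messages, but its body here is inferred):

```python
import torch.nn as nn
import pytorch_lightning as pl


class LightningModuleFromTorch(pl.LightningModule):
    """Sketch of the wrapper named in the commit history; details assumed."""

    def __init__(self, model: nn.Module, loss=None, optimizer=None):
        super().__init__()
        self.model = model
        self.loss = loss
        self.optimizer = optimizer

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return self.loss(self(x), y)

    def configure_optimizers(self):
        return self.optimizer


def compile(model: nn.Module, loss=None, optimizer=None):
    # If the input is already a LightningModule, return it unchanged, so
    # existing Lightning code keeps working (the pass-through discussed above).
    if isinstance(model, pl.LightningModule):
        return model
    return LightningModuleFromTorch(model, loss, optimizer)
```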
Passed tests. It's ready now.
* added decorator to create a pytorch lightning model from torch
* added unit test for pytorch lightning decorator
* refactoring - renaming, adding hints and docstring
* moved lightning extension to nano/pytorch
* remove loss, optim creator and directly pass loss and optimizer to initiate
* added another implementation for pytorch to lightning
* use LightningModuleFromTorch to create lightning module from pytorch
* remove temporary change
* remove redundant part
* added trainer.compile to convert pytorch to pytorch-lightning
* added unit test for trainer.compile
* fixed return when input is pl model
* added type hint for LightningModuleFromTorch.copy
* Renamed copy as _copy
* Modified comment of compile
* added input checking
* refactored docstring
* Reformat docstring
* Tiny changes
* reformat
* correct the import
* type check and
* assign model as a member variable
* override load_state_dict
* fix test_trainer_compile
* fix test_lightning
* try lightning module and then self.model
* rename _forward as forward
* type check
* optimize imports

Co-authored-by: Yang Wang <yang3.wang@intel.com>
Resolved 1st task in #3171.
`Trainer.compile` for user to use.