[feature request] ModuleTrainer class #7

Closed
ncullen93 opened this issue Apr 28, 2017 · 6 comments

Comments

@ncullen93
Member

ncullen93 commented Apr 28, 2017

Could potentially add a ModuleTrainer or ModelTrainer class that works similarly to SuperModule but can take in one or more regular nn.Module models. This would allow support for pre-trained networks and more flexible training structures, while also allowing seamless integration with all other PyTorch code.
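A minimal sketch of the idea, assuming a Keras-style interface. The class name ModuleTrainer comes from this issue, but the compile/fit method names and signatures are assumptions for illustration, not torchsample's actual API. The point is that the trainer only holds a reference to an ordinary nn.Module instead of requiring the model to subclass SuperModule; several modules could be combined (e.g. with nn.Sequential) before being passed in.

```python
import torch
import torch.nn as nn


class ModuleTrainer:
    """Minimal sketch: a trainer that wraps an ordinary nn.Module.

    Method names (compile, fit) are assumed, Keras-style; not torchsample's API.
    """

    def __init__(self, model: nn.Module):
        self.model = model  # any nn.Module, e.g. a pre-trained network
        self._loss_fn = None
        self._optimizer = None

    def compile(self, loss, optimizer):
        # Keep the loss function and an optimizer built over self.model's parameters.
        self._loss_fn = loss
        self._optimizer = optimizer

    def fit(self, x, y, num_epoch=1):
        # Full-batch training loop; returns the loss computed at each epoch.
        self.model.train()
        losses = []
        for _ in range(num_epoch):
            self._optimizer.zero_grad()
            loss = self._loss_fn(self.model(x), y)
            loss.backward()
            self._optimizer.step()
            losses.append(loss.item())
        return losses


# Usage: fit a plain nn.Linear to y = x1 + x2 + x3.
torch.manual_seed(0)
net = nn.Linear(3, 1)
trainer = ModuleTrainer(net)
trainer.compile(loss=nn.MSELoss(), optimizer=torch.optim.SGD(net.parameters(), lr=0.1))
x = torch.randn(32, 3)
y = x.sum(dim=1, keepdim=True)
losses = trainer.fit(x, y, num_epoch=50)
```

Because the model is never subclassed, the same trainer works unchanged for a model loaded from elsewhere, which is what makes pre-trained networks usable.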

@recastrodiaz
Contributor

Sounds like a great idea! ;)
I'll get back to you with some questions Monday or sooner if possible.

@ncullen93
Member Author

ncullen93 commented Apr 29, 2017

Check out what I just put in modules/model_trainer.py. I added initializations as well.
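The comment does not say how initializations are exposed in modules/model_trainer.py. As a hedged illustration only, this is one way initializers can be applied to a plain nn.Module with modern PyTorch's nn.init functions (the in-place, trailing-underscore names such as xavier_uniform_ postdate this 2017 thread). The helper name init_linear_weights is hypothetical.

```python
import torch.nn as nn


def init_linear_weights(model: nn.Module) -> nn.Module:
    """Hypothetical helper: Xavier-uniform weights and zero biases for every nn.Linear."""

    def _init(module):
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    # Module.apply calls _init on every submodule recursively, then on model itself.
    model.apply(_init)
    return model


net = init_linear_weights(nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)))
```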

ncullen93 reopened this Apr 29, 2017
@recastrodiaz
Contributor

Amazing! This is definitely what I had in mind! I saw that you added support for multiple forward inputs. That's something I was really looking for!
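Support for multiple forward inputs usually comes down to unpacking a list of tensors into the model's forward() call. A minimal sketch, assuming the trainer accepts either a single tensor or a list/tuple of tensors; TwoInputNet and forward_multi are hypothetical names, not torchsample code.

```python
import torch
import torch.nn as nn


class TwoInputNet(nn.Module):
    """Hypothetical model whose forward() takes two tensors."""

    def __init__(self):
        super().__init__()
        self.a = nn.Linear(3, 4)
        self.b = nn.Linear(5, 4)
        self.head = nn.Linear(4, 1)

    def forward(self, x1, x2):
        # Project each input into a shared 4-d space, sum, then map to one output.
        return self.head(torch.relu(self.a(x1) + self.b(x2)))


def forward_multi(model, inputs):
    """Call model with one tensor, or unpack a list/tuple of tensors (assumed convention)."""
    if isinstance(inputs, (list, tuple)):
        return model(*inputs)
    return model(inputs)


net = TwoInputNet()
out = forward_multi(net, [torch.randn(8, 3), torch.randn(8, 5)])
print(out.shape)  # torch.Size([8, 1])
```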

The History callback class seems to be looking for nonexistent attributes on the model. Maybe the History constructor should accept a ModuleTrainer as an init parameter? (Although this sounds like a circular dependency...) For now, I've removed it from the callbacks in the ModuleTrainer __init__ function:

# Before:
self._callbacks = [self.history]

# After:
self._callbacks = []

The error raised in the History class:
/Users/rodrigo/Libs/torchsample/torchsample/callbacks.py in on_train_begin(self, logs)
    168     def on_train_begin(self, logs=None):
    169         self.losses = []
--> 170         if self.model._has_regularizers:
    171             self.regularizer_losses = []
    172         if self.model._has_lagrangian_constraints:

Thanks again for the excellent work!
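One way to sidestep the circular dependency raised above, sketched in plain Python: build callbacks without a model reference and let the trainer inject itself afterwards (Keras callbacks follow the same pattern with set_model), and read optional flags with getattr defaults so a trainer wrapping an ordinary nn.Module does not need them. All class names and the set_trainer method here are hypothetical; the thread does not show how the actual fix was implemented.

```python
class Callback:
    """Base callback; the trainer injects itself after construction."""

    def set_trainer(self, trainer):
        self.trainer = trainer

    def on_train_begin(self, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        pass


class History(Callback):
    def on_train_begin(self, logs=None):
        self.losses = []
        # getattr with a default: a trainer wrapping a plain module may lack this flag.
        has_reg = getattr(self.trainer, "_has_regularizers", False)
        self.regularizer_losses = [] if has_reg else None

    def on_epoch_end(self, epoch, logs=None):
        self.losses.append(logs["loss"])


class CallbackList:
    def __init__(self, callbacks, trainer):
        self.callbacks = list(callbacks)
        for cb in self.callbacks:
            cb.set_trainer(trainer)  # link is set here, not in History's constructor

    def on_train_begin(self, logs=None):
        for cb in self.callbacks:
            cb.on_train_begin(logs)

    def on_epoch_end(self, epoch, logs=None):
        for cb in self.callbacks:
            cb.on_epoch_end(epoch, logs)


class DummyTrainer:
    """Stand-in for ModuleTrainer that defines no regularizer flags at all."""

    def __init__(self):
        self.history = History()
        self._callbacks = [self.history]  # History can stay in the list

    def fit(self, epoch_losses):
        cbs = CallbackList(self._callbacks, self)
        cbs.on_train_begin()
        for epoch, loss in enumerate(epoch_losses):
            cbs.on_epoch_end(epoch, {"loss": loss})
        return self.history


hist = DummyTrainer().fit([0.9, 0.5, 0.3])
print(hist.losses)  # [0.9, 0.5, 0.3]
print(hist.regularizer_losses)  # None
```

The trainer and History still reference each other at runtime, which Python handles fine; injecting the link after construction only removes the ordering problem of needing one object before the other exists.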

@ncullen93
Member Author

This should be fixed. Can you check?

@ncullen93
Member Author

Closed because ModuleTrainer was added.

@recastrodiaz
Contributor

Can confirm ModuleTrainer works!
