
(Optional) state_dict for each transform (reproducibility) #1055

Closed
nicoloesch opened this issue Mar 15, 2023 · 2 comments
Labels
enhancement New feature or request

Comments

@nicoloesch
Contributor

🚀 Feature
This feature would allow saving each transform in a reproducible format that can be restored from a saved state. This comes in handy if torchio is used in combination with pytorch_lightning, where the state_dict method aims to save the state of the LightningModule (model) and LightningDataModule (data) in a model checkpoint. If the configuration of each transform could be saved in a simple way, e.g. a dict with all keys, one could easily restore each transform from a saved state and thereby restore the entire LightningDataModule without any issues. This extends beyond pytorch_lightning, but that case illustrates the mechanism in an easily understandable way.

Motivation

Continuing training or testing a model from a pre-trained state is common practice in machine learning. In order to support state recovery, each transform class would require state_dict and load_state_dict methods so it can be restored with the parameters used in the previous training run. The other option is to implement both methods in the Transform base class and store the class name of the transform object alongside its parameters, from which the correct child class can be instantiated; see the sketch below.
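A minimal sketch of that second option (transform_state_dict, load_transform, and the tracked kwargs are hypothetical, not existing torchio API; it assumes the caller keeps the constructor keyword arguments around):

```python
import torchio as tio

# Hypothetical helpers, not part of torchio: serialise a transform as its
# class name plus constructor kwargs, and rebuild it by looking the class
# up in tio.transforms again.

def transform_state_dict(transform, init_kwargs):
    # init_kwargs must be the keyword arguments the transform was built with
    return {'class_name': type(transform).__name__, 'kwargs': dict(init_kwargs)}

def load_transform(state):
    cls = getattr(tio.transforms, state['class_name'])
    return cls(**state['kwargs'])

kwargs = {'scales': 0.1, 'degrees': 20}
affine = tio.RandomAffine(**kwargs)
state = transform_state_dict(affine, kwargs)
restored = load_transform(state)  # a RandomAffine with the same parameters
```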

Pitch

Save the state of the model, including the data module, in a recoverable dictionary. This requires saving the state of each transform, and allowing that state to be restored for each transform.

Alternatives
Just instantiate each transform from an associated hparams file that is used to recreate the transforms from scratch, without needing to save the state of the transform; see the sketch below.
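For example, a sketch of this alternative (the hparams dict, its keys, and the parameter values are illustrative only, e.g. loaded from a YAML/JSON hparams file):

```python
import torchio as tio

# Rebuild the augmentation pipeline from plain hyper-parameters;
# no transform state is saved anywhere.
hparams = {
    'RandomAffine': {'scales': 0.1, 'degrees': 20},
    'RandomNoise': {'std': 0.05},
}

transform = tio.Compose(
    [getattr(tio.transforms, name)(**kwargs) for name, kwargs in hparams.items()]
)
```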

Additional context

Follows the implementation of the LightningDataModule of pytorch_lightning. If one can obtain the correct state of each of the transforms, one can compose the initial LightningDataModule again.
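A rough sketch of how this could look on the datamodule side, assuming the state_dict/load_state_dict checkpoint hooks that LightningDataModule exposes (MRIDataModule and its arguments are hypothetical):

```python
import pytorch_lightning as pl
import torchio as tio

class MRIDataModule(pl.LightningDataModule):
    # Hypothetical datamodule: the transform kwargs travel inside the
    # Lightning checkpoint, so the pipeline can be rebuilt on restore.
    def __init__(self, transform_kwargs=None):
        super().__init__()
        self.transform_kwargs = transform_kwargs or {}
        self.transform = tio.RandomAffine(**self.transform_kwargs)

    def state_dict(self):
        return {'transform_kwargs': self.transform_kwargs}

    def load_state_dict(self, state_dict):
        self.transform_kwargs = state_dict['transform_kwargs']
        self.transform = tio.RandomAffine(**self.transform_kwargs)
```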

@romainVala
Contributor

Hello,
I am not sure I understand which exact state you need to save.

Let's say you have a RandomAffine transform instantiated with scales=1, degrees=20, translation=0, so each time you get new data you get a new rotation.

What do you want to save?
The exact rotation value of the last transform? (What for, then?)

If the objective is to resume training from where you stopped, I do not see the point of saving the state of the last used transform (i.e. the rotation value): to continue training you can just use the same RandomTransform you started with ... (because it is random ....) no?
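For instance, a minimal sketch of that argument (assuming torchio samples its random parameters from torch's RNG; the seed value is arbitrary):

```python
import torch
import torchio as tio

# Re-create the same augmentation pipeline on resume; no per-transform
# state is needed because the parameters are sampled anew each time.
torch.manual_seed(42)  # optional: makes the random stream itself repeatable
transform = tio.RandomAffine(scales=1, degrees=20)
```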

But maybe I am missing something here.

@nicoloesch
Contributor Author

Hi,

I think you might be right! I was thinking that the transform could differ between subjects, e.g. w.r.t. in_min_max for some intensity transforms, but as long as the initial list of subjects and the SubjectsDataset can be restored (by just initialising the same transforms and the same subjects), I can restore my state.

This feature request can therefore be closed (I think)!
