
Add parametrization for normalization #20

Open

y-prudent wants to merge 6 commits into develop from feat-parametrization-for-normalization

Conversation

@y-prudent y-prudent (Member) commented Nov 29, 2023

Prerequisite: Torch parametrization tutorial

Features

  • Rewrite bjorck, frobenius, and lconv normalizations so that they use torch parametrization instead of forward pre-hooks (see the sketch after the snippet below).
  • Faster inference in model.eval() mode: when model.training is False, spectral and bjorck normalizations reuse cached tensors, so the normalization comes for free.
  • Vanilla model conversion can now be done on any model (not only Sequential) with vanilla_model. Be careful: this is an in-place conversion!
  • The torch parametrize.cached() feature is now also usable on Lipschitz layers, which saves memory and compute when the same kernel is applied several times in a single inference step (very useful for RNNs, multi-level convolutions, etc.). Here is how to use it:
import torch.nn.utils.parametrize as P

with P.cached():
    y1 = lip_layer(x1)    # at the first call, compute normalization and save the reparametrized weights
    y2 = lip_layer(x2)    # at the second call, reuse saved reparametrized weights
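
For reference, here is a minimal sketch of how a normalization can be expressed as a torch parametrization instead of a forward pre-hook. This is not the torchlip implementation: the FrobeniusNormalize module and the plain nn.Linear layer are purely illustrative.

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class FrobeniusNormalize(nn.Module):
    # Illustrative parametrization: divide the weight by its Frobenius norm
    def forward(self, weight):
        return weight / weight.norm()

layer = nn.Linear(16, 16)
# The parametrization is recomputed each time layer.weight is accessed,
# instead of being triggered by a forward pre-hook
parametrize.register_parametrization(layer, "weight", FrobeniusNormalize())

y = layer(torch.randn(4, 16))  # uses the weight computed from the original tensor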

⚠️ Important note:

Models using parametrized modules can only be serialized through state_dict(). As a result, torch.save(model, PATH) is no longer possible and will raise an error. Instead, save and load your models like this:

# save model
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "model_kwargs": config,  # arguments used to build the model
    },
    PATH,
)

# load model
checkpoint = torch.load(PATH)
model = TheModelClass(**checkpoint["model_kwargs"])
model.load_state_dict(checkpoint["model_state_dict"])
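
This works because the parametrizations are re-registered when TheModelClass is instantiated, so load_state_dict only has to restore the underlying (unconstrained) tensors.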

For more information, check this torch tutorial.

TODO

  • Rewrite bjorck, frobenius, and lconv normalizations with parametrization
  • Manage the removal of individual parametrizations (torch's native functions only allow removing all the parametrizations of a tensor at once; see the sketch after this list)
  • Manage the parametrization of lconv_norm, which depends on the input shape (=> solution: forward pre-hooks + parametrization)
  • Also parametrize the global Lipschitz coefficient multiplication
  • Write the associated tests
  • Remove the deprecated hook-related files
  • Check that vanilla_export works well
  • Check the parametrize.cached() feature
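
For the items about removing parametrizations and vanilla_export, a minimal sketch of the torch-native building block is shown below. It is only an illustration (the FrobeniusNormalize module is a toy parametrization, not the torchlip one), and remove_parametrizations indeed drops all parametrizations of a given tensor at once.

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class FrobeniusNormalize(nn.Module):
    # Illustrative parametrization: divide the weight by its Frobenius norm
    def forward(self, weight):
        return weight / weight.norm()

layer = nn.Linear(16, 16)
parametrize.register_parametrization(layer, "weight", FrobeniusNormalize())

# Bake the current normalized value into a plain nn.Parameter and drop the
# parametrization machinery; the layer then behaves like a vanilla module
parametrize.remove_parametrizations(layer, "weight", leave_parametrized=True)
print(parametrize.is_parametrized(layer))  # False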

@y-prudent y-prudent force-pushed the feat-parametrization-for-normalization branch from 7ee8f09 to a3b262f on November 29, 2023 18:04
@y-prudent y-prudent force-pushed the feat-parametrization-for-normalization branch from a3b262f to 80a75a0 on November 29, 2023 18:47
@y-prudent y-prudent marked this pull request as ready for review November 30, 2023 12:58
@cofri cofri (Collaborator) left a comment
Very interesting PR!
Just a small suggestion for the vanilla_export import.

@@ -8,7 +8,7 @@ jobs:
     strategy:
       max-parallel: 4
       matrix:
-        python-version: [3.6, 3.7, 3.8]
+        python-version: [3.7, 3.8, 3.9]
@cofri cofri (Collaborator) commented:

From the status of Python versions, Python 3.7 is already deprecated (end-of-life). Maybe we can use a broader test matrix with more recent Python versions (and also PyTorch versions?).
That is a bigger change, though, and it could be postponed to a future PR.

deel/torchlip/modules/module.py (resolved review thread)