In [None]:
# Basic imports and some hyperparameters
import torch
import torch.nn as nn

# Problem sizes: matrices are N x N (4 x 4) or N x M (4 x 5)
N, M = 4, 5
# Number of matrices in a batch
B = 3
# Random batch of B square N x N matrices
x = torch.rand(B, N, N)

# Parametrisations ([PR #33344](https://github.com/pytorch/pytorch/pull/33344))

This notebook provides an introduction to the design of parametrisations in PyTorch. Parametrisations are the way `geotorch` works behind the scenes, so having some grip on how they work should greatly help in using `geotorch` effectively.

## Motivating Example

Given a function `f` and a `Parameter` `X` which is registered on a module, we would like to be able to use `f(X)` in place of `X`.

This is easier understood with an example. Suppose that we want to have a linear layer whose matrix is symmetric. We could write:

In [None]:
class Symmetric(nn.Module):
    """Bias-free linear map whose matrix is symmetric by construction.

    The free parameter ``weight`` is an unconstrained square matrix; only its
    upper triangle (diagonal included) is used to build the matrix that
    multiplies the input.
    """

    def __init__(self, n_features):
        super().__init__()
        initial = torch.rand(n_features, n_features)
        self.weight = nn.Parameter(initial)

    def forward(self, x):
        upper = self.weight.triu()
        # upper + upper^T is symmetric; note the diagonal of `weight` is counted twice
        return x @ (upper + upper.T)
# Build a symmetric layer acting on N features
layer = Symmetric(N)
# Apply it to the batch `x`; the trailing semicolon keeps the notebook from echoing the result
layer(x);

This implementation clearly has two components, a reimplementation of `nn.Linear` and a parametrisation of the symmetric matrices:

In [None]:
class SymmetricParametrization(nn.Module):
    """Parametrization that maps a square matrix to a symmetric matrix.

    Only the upper triangle of the input (diagonal included) is used; the
    result is that triangle plus its transpose.
    """

    def forward(self, X):
        """Return the symmetric matrix built from the upper triangle of ``X``.

        Args:
            X: a 2-D square tensor.

        Returns:
            ``X.triu() + X.triu().T``.
        """
        # `self` is required: nn.Module.__call__ invokes forward as a bound method
        A = X.triu()
        return A + A.T

## Objective

We would like to separate these two, and have a mechanism to be able to inject a parametrisation onto a parameter or a buffer in a neural network. In particular, we would like to be able to do the following:

```python
layer = nn.Linear(N, N)
torch.register_parametrization(layer, "weight", SymmetricParametrization())
# layer now behaves as an object from the `Symmetric` class
print(layer.weight)  # Prints the symmetric matrix
layer(x)             # Multiplies the vectors `x` by the symmetric matrix layer.weight 
```

## Examples

### Symmetric layers

(see above)

### Pruning
When doing pruning, one samples a boolean mask of the size of the parameter and does an element-wise multiplication. It seems that one may train a neural network and then make it somewhat sparse, and everything magically works. This is called the "lottery ticket hypothesis". (see `torch.nn.utils.prune`)

A simple pruning method that prunes an entry of the tensor with some given probability could go as:

In [None]:
class PruningParametrization(nn.Module):
    """Parametrization that zeroes out a fixed random subset of entries.

    A binary mask with the shape of ``X`` is sampled once at construction
    time and multiplied element-wise with the tensor on every forward call.
    """

    def __init__(self, X, p_drop=0.2):
        """Sample the pruning mask.

        Args:
            X: tensor whose shape, dtype and device the mask copies.
            p_drop: probability with which each entry is zeroed.
        """
        # Without this call the instance has no nn.Module state and cannot be called
        super().__init__()
        # Each entry is kept (mask == 1) with probability 1 - p_drop
        keep_prob = torch.full_like(X, 1.0 - p_drop)
        # A buffer (rather than a plain attribute) follows the module through
        # `.to(...)` and is stored in its `state_dict`
        self.register_buffer("mask", torch.bernoulli(keep_prob))

    def forward(self, X):
        """Return ``X`` with the masked entries set to zero."""
        return X * self.mask

We would like to use it as:
```python
cnn = nn.Conv2d(8, 16, (3, 3))
torch.register_parametrization(cnn, "weight", PruningParametrization(cnn.weight, p_drop=0.1))
# About 10% of the entries of the tensor cnn.weight have now been zeroed out
```
### Other examples:
- `torch.weight_norm`
- `torch.spectral_norm` (to regularise the Lipschitz constant of a layer)
- Optimisation with orthogonal constraints / invertible layers / Symmetric Positive Definite layers... More on this later

## Implementing `torch.register_parametrization`
### A first approximation
A moment's reflection shows that it is possible to implement `Symmetric` without having to reimplement `nn.Linear` by using inheritance and properties.

```python
class SymmetricRevisited(nn.Linear):
    def __init__(self, n_features):
        super().__init__(n_features, n_features, bias=False)
        # Rename weight attribute to _weight
        self._weight = self.weight
        delattr(self, "weight")
    
    @property
    def weight(self):
        A = self._weight.triu()
        return A + A.T
```

Note: This code does not work! It is possible to make it work using metaclasses (for example), but we will skip that.

### A caching system

Sometimes we use the same layer many times in the forward pass of a neural network (e.g., in the recurrent kernel of an RNN). In those cases, we would not want to recompute `layer.weight` every time we execute it. We would like to compute it at the beginning of the forward pass and cache the result throughout the whole forward pass.

We can achieve that by implementing a caching system as follows:

In [None]:
from contextlib import contextmanager
# Nesting depth of active `cached()` blocks; zero means caching is off
_cache_enabled = 0
# Store for values computed while caching is on
_cache = {}


@contextmanager
def cached():
    """Context manager that switches caching on for the enclosed block.

    Blocks may be nested: a counter tracks the depth, and the cache is
    replaced by a fresh empty dict when the outermost block exits, whether
    normally or through an exception.
    """
    global _cache, _cache_enabled
    _cache_enabled += 1
    try:
        yield
    finally:
        _cache_enabled -= 1
        if _cache_enabled == 0:
            # Leaving the outermost block: drop everything that was cached
            _cache = {}

class SymmetricCached(nn.Module):
    """Symmetric linear layer whose ``weight`` is a cached, computed property.

    The unconstrained tensor is stored in ``_weight``; ``weight`` applies
    ``parametrization`` to it. Inside a ``cached()`` block the result is
    computed once and reused; outside of it, it is recomputed on every
    access so that it always reflects the current value of ``_weight``.
    """

    def __init__(self, n_features):
        super().__init__()
        # The free parameter lives under `_weight`; `weight` is the derived property
        self._weight = nn.Parameter(torch.rand(n_features, n_features))

    def parametrization(self, X):
        """Return the symmetric matrix built from the upper triangle of ``X``."""
        # Printed so that the notebook shows how often the map is evaluated
        print("Computing")
        A = X.triu()
        return A + A.T

    @property
    def weight(self):
        """The symmetric matrix derived from ``_weight``."""
        # `_cache` is only emptied when a `cached()` block exits, so storing
        # results while caching is off would keep returning a stale matrix
        # after `_weight` has been updated
        if not _cache_enabled:
            return self.parametrization(self._weight)

        key = (id(self), "weight")
        if key not in _cache:
            _cache[key] = self.parametrization(self._weight)
        return _cache[key]

    def forward(self, x):
        return x @ self.weight.T

# Usage: inside `cached()` both accesses to `layer.weight` share one computation
layer = SymmetricCached(N)
with cached():
    # The difference with the transpose should be the zero matrix
    W = layer.weight
    print(W - layer.weight.T)


### A generic implementation

Now, all we need to do is to implement a function that, given a module, a name, and a parametrisation (i.e., another module), injects a property similar to how we did it manually in `SymmetricCached`. In particular, we have to write a function with signature
```python
def register_parametrization(module: Module, tensor_name: str, parametrization: Module) -> None:
```
that does:

- Renames the tensor from `tensor_name` to `f"_{tensor_name}"`
- Saves `parametrization` within `module` to use it in the forward pass
- Injects a property with the name `tensor_name` that computes `parametrization` applied to the renamed tensor (i.e. `parametrization(getattr(module, f"_{tensor_name}"))`) when accessed

The first two things are direct. To implement the third one, we use the `type` function.

In [None]:
def inject_property(module, tensor_name):
    """Replace attribute ``tensor_name`` of ``module`` by a computed property.

    A new class is created on the fly that inherits from the current class of
    ``module`` and defines a property called ``tensor_name``; the class of
    ``module`` is then swapped for it, so only this instance is affected.

    Args:
        module: the object whose class is replaced. The property reads
            ``module.parametrizations[tensor_name]``, which this function
            does not create.
        tensor_name: name of the property to inject.
    """
    # We create a new class so that we can inject properties in it
    cls_name = "Parametrized" + module.__class__.__name__

    # Define the getter
    def getter(module):
        global _cache

        # NOTE(review): `_key` is not defined in the visible part of this
        # file; presumably it builds the cache key from the module and the
        # tensor name -- confirm it is in scope before the property is read.
        key = _key(module, tensor_name)
        # If the _cache is not enabled or the caching was not enabled for this
        # tensor, this function just evaluates the parametrization
        if _cache_enabled and key in _cache:
            # The key being present with value None means "not computed yet":
            # compute it now and keep the result for later accesses
            if _cache[key] is None:
                _cache[key] = module.parametrizations[tensor_name]()
            return _cache[key]
        else:
            return module.parametrizations[tensor_name]()

    # Define the setter
    def setter(module, value):
        # NOTE(review): the example parametrizations in this notebook name
        # this hook `initialize_`; confirm that the object stored in
        # `module.parametrizations` really exposes `initialize`.
        module.parametrizations[tensor_name].initialize(value)
        
    # Create a new class that inherits from `module.__class__` and has a property called `tensor_name`
    param_cls = type(cls_name, (module.__class__,), {
        tensor_name: property(getter, setter)
    })
    # Swap the class of this instance only; other instances keep the original class
    module.__class__ = param_cls

layer = nn.Linear(3, 4)
inject_property(layer, "weight")
# The instance now belongs to the dynamically created subclass...
parametrized_cls = type(layer)
print(parametrized_cls)
# ...on which `weight` is a property object
print(parametrized_cls.weight)

## Other things that `torch.register_parametrization` allows:

- If the module implements an `initialize_` method (similar to a right-inverse of forward, more on this below), it allows initialising the parametrised buffer/parameter
- It allows putting several parametrisations on the same buffer/parameter
- It allows removing the parametrisations, leaving either the original parameter or the parametrised parameter
- Any combination of the above

## More applications of parametrizations

- Constrained optimisation on manifold using `geotorch`!
- Normalising flows. The `initialize_` method can be implemented as a right-inverse of forward. 
    - In the simplest case, if `forward` is a diffeomorphism, then this reduces to the usual normalising flows framework.
    - The general case comes when the forward is a [submersion](https://en.wikipedia.org/wiki/Submersion_(mathematics)) (a function with differentiable local right-inverses). An example of this is a linear layer from `R^n` to `R^k` with `n > k` that is full rank (e.g. a `k x n` matrix with orthogonal rows). Using a submersion, one may construct a generalisation of normalising flows that allows for dimensionality reduction. The simplest case of this setting comes from projecting a vector in `R^n` onto its first `k` components. This is called in the normalising flows literature "multi-scale architecture", and it was introduced in the model [real NVP](https://arxiv.org/abs/1605.08803).
    
## Examples of some simple parametrisations, composing them, and initialising them

In [None]:
# This part assumes that you have `geotorch` installed. You can install it doing
# pip install git+https://github.com/Lezcano/geotorch/
import geotorch.parametrize as P

class Skew(nn.Module):
    """Parametrization onto skew-symmetric matrices.

    ``forward`` keeps the strictly upper triangle of its input and subtracts
    its transpose. ``initialize_`` is a right inverse of ``forward``:
    ``forward(initialize_(X)) == X`` for every skew-symmetric ``X``.
    All methods act on the last two dimensions, so batches are accepted.
    """

    def forward(self, X):
        """Return the skew-symmetric matrix built from the strict upper triangle of ``X``."""
        X = X.triu(1)
        # transpose(-2, -1) equals `.T` for matrices and also handles batches
        return X - X.transpose(-2, -1)

    def is_skew(self, X):
        """Return True if the norm of ``X + X^T`` is below ``1e-5``."""
        return torch.norm(X + X.transpose(-2, -1)).item() < 1e-5

    def initialize_(self, X):
        """Return a tensor that ``forward`` maps back to ``X``.

        Args:
            X: the skew-symmetric matrix the parametrized tensor should take.

        Raises:
            ValueError: if ``X`` is not skew-symmetric.
        """
        if not self.is_skew(X):
            raise ValueError(
                "Skew can only be initialised with a skew-symmetric matrix "
                "(X + X^T must be zero)"
            )
        return X.triu(1)
    
# Skew.forward(Skew.initialize_(X)) == X
# In functional notation: Skew.forward o Skew.initialize_ = Id
# In other words, initialize_ is a right inverse of forward.

model = nn.Linear(5, 5)
P.register_parametrization(model, "weight", Skew())
# Inside `P.cached()` the parametrised `model.weight` is only computed once
with P.cached():
    assert torch.norm(model.weight + model.weight.T) == 0
# Build a random skew-symmetric matrix X and use it to initialise model.weight
X = torch.rand(5, 5)
X = X - X.T
model.weight = X
assert torch.norm(model.weight - X) < 1e-4

In [None]:
class Orthogonal(nn.Module):
    """Parametrization mapping skew-symmetric matrices to orthogonal ones.

    The output is ``B @ Q(A)`` where ``Q(A) = (I + A)^{-1} (I - A)`` is a
    Cayley transform and ``B`` is a buffer holding a base matrix (the
    identity by default). ``initialize_`` is a right inverse of ``forward``.
    """

    def __init__(self, n):
        super().__init__()
        # Base matrix that the Cayley factor is multiplied with
        self.register_buffer("B", torch.eye(n))

    def forward(self, A):
        """Return ``B @ (I + A)^{-1} (I - A)`` for a square matrix ``A``."""
        # Build the identity with the dtype/device of the input so that
        # float64 or GPU tensors do not clash with a default CPU float32 eye
        Id = torch.eye(A.size(0), dtype=A.dtype, device=A.device)
        # Solve (I + A) Q = (I - A). For skew-symmetric A the factors commute
        # and Q^T Q = I, i.e. Q is orthogonal. `torch.linalg.solve(lhs, rhs)`
        # replaces the removed `torch.solve(rhs, lhs).solution`.
        return self.B @ torch.linalg.solve(Id + A, Id - A)

    def is_orthogonal(self, X):
        """Return a boolean tensor telling whether ``X^T X`` is within ``1e-4`` of ``I``."""
        Id = torch.eye(X.size(0), dtype=X.dtype, device=X.device)
        return torch.norm(X.T @ X - Id) < 1e-4

    def initialize_(self, X):
        """Make ``forward`` return ``X`` and give back the matching input.

        Args:
            X: the orthogonal matrix the parametrized tensor should take.

        Returns:
            A zero matrix shaped like ``X``; it is a valid (skew-symmetric)
            input for which ``forward`` evaluates to ``X``.

        Raises:
            ValueError: if ``X`` is not orthogonal.
        """
        if not self.is_orthogonal(X):
            raise ValueError(
                "Orthogonal can only be initialised with an orthogonal matrix "
                "(X^T X must be the identity)"
            )
        # Q(0) == Id, so B @ Q(0) == B. Store a detached copy so the buffer
        # neither carries autograd history nor aliases the caller's tensor.
        self.B = X.detach().clone()
        return torch.zeros_like(X)


model = nn.Linear(5, 5)
# Two parametrisations are registered on the same tensor: first Skew, then Orthogonal
P.register_parametrization(model, "weight", Skew())
P.register_parametrization(model, "weight", Orthogonal(5))

# Sample an orthogonal matrix and initialise the layer.
# `orthogonal_` must fill X itself: X starts as uninitialised memory, and
# filling `model.weight` instead would leave X as garbage when it is assigned.
X = torch.empty_like(model.weight)
nn.init.orthogonal_(X)
model.weight = X

# model.weight == X
assert torch.allclose(model.weight, X)

# A more programmatic way of initialising the weight
model.weight = nn.init.orthogonal_(model.weight)