In [1]:
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

In [3]:
# assume we want a square linear layer with symmetric weights (X = X^T)
# only way to do this is to copy the upper triangular part of the matrix to the lower triangular part

def symmetric(X):
    return X.triu() + X.triu(1).transpose(-1, -2)

X = torch.rand(3, 3)
A = symmetric(X)
assert torch.allclose(A, A.T)
print(A)

tensor([[0.3273, 0.9011, 0.4284],
        [0.9011, 0.4835, 0.8893],
        [0.4284, 0.8893, 0.1873]])


In [14]:
class LinearSymmetric(nn.Module):
    def __init__(self, n_features):
        super().__init__()
        self.weight = nn.Parameter(torch.rand(n_features, n_features))
    
    def forward(self, x):
        A = symmetric(self.weight)
        return x @ A

In [15]:
layer = LinearSymmetric(3)
out = layer(torch.rand(8, 3))

In [17]:
class Symmetric(nn.Module):
    def forward(self, X):
        return X.triu() + X.triu(1).transpose(-1, -2)

In [18]:
layer = nn.Linear(3, 3)
parametrize.register_parametrization(layer, "weight", Symmetric())

ParametrizedLinear(
  in_features=3, out_features=3, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): Symmetric()
    )
  )
)

In [19]:
A = layer.weight
assert torch.allclose(A, A.T)
print(A)

tensor([[ 0.2497, -0.3045, -0.4517],
        [-0.3045,  0.3438, -0.3698],
        [-0.4517, -0.3698, -0.2481]], grad_fn=<AddBackward0>)


In [20]:
class Skew(nn.Module):
    def forward(self, X):
        return X.triu() - X.triu(1).transpose(-1, -2)

In [21]:
cnn = nn.Conv2d(in_channels=5, out_channels=8, kernel_size=3)
parametrize.register_parametrization(cnn, "weight", Skew())
print(cnn.weight[0, 1])
print(cnn.weight[2, 2])

tensor([[-0.0663,  0.0441,  0.0988],
        [-0.0441,  0.1102, -0.1088],
        [-0.0988,  0.1088, -0.0356]], grad_fn=<SelectBackward0>)
tensor([[ 0.0100,  0.1078,  0.0305],
        [-0.1078, -0.0107,  0.1232],
        [-0.0305, -0.1232, -0.0985]], grad_fn=<SelectBackward0>)


In [23]:
layer = nn.Linear(3, 3)
print(f"Unparametrized: \n{layer}")
parametrize.register_parametrization(layer, "weight", Symmetric())
print(f"Parametrized: \n{layer}")

Unparametrized: 
Linear(in_features=3, out_features=3, bias=True)
Parametrized: 
ParametrizedLinear(
  in_features=3, out_features=3, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): Symmetric()
    )
  )
)


In [24]:
print(layer.parametrizations)
print(layer.parametrizations.weight)

ModuleDict(
  (weight): ParametrizationList(
    (0): Symmetric()
  )
)
ParametrizationList(
  (0): Symmetric()
)


In [25]:
layer.parametrizations.weight[0]

Symmetric()

In [26]:
print(dict(layer.named_parameters()))

{'bias': Parameter containing:
tensor([-0.5430,  0.3678,  0.4536], requires_grad=True), 'parametrizations.weight.original': Parameter containing:
tensor([[-0.4718, -0.3407, -0.0715],
        [ 0.2371,  0.2192, -0.4169],
        [ 0.2739, -0.5155, -0.4661]], requires_grad=True)}


In [27]:
layer.parametrizations.weight.original

Parameter containing:
tensor([[-0.4718, -0.3407, -0.0715],
        [ 0.2371,  0.2192, -0.4169],
        [ 0.2739, -0.5155, -0.4661]], requires_grad=True)

In [28]:
symmetric = Symmetric()
weight_orig = layer.parametrizations.weight.original
print(torch.dist(layer.weight, symmetric(weight_orig)))

tensor(0., grad_fn=<DistBackward0>)


In [29]:
class NoisyParametrization(nn.Module):
    def forward(self, X):
        print("Computing the Parametrization")
        return X
    
layer = nn.Linear(4, 4)
parametrize.register_parametrization(layer, "weight", NoisyParametrization())
print("Here, layer.weight is recomputed every time we call it")
foo = layer.weight + layer.weight.T
bar = layer.weight.sum()
with parametrize.cached():
    print("Here, layer.weight is recomputed only once")
    foo = layer.weight + layer.weight.T
    bar = layer.weight.sum()

Computing the Parametrization
Here, layer.weight is recomputed every time we call it
Computing the Parametrization
Computing the Parametrization
Computing the Parametrization
Here, layer.weight is recomputed only once
Computing the Parametrization


In [33]:
class CayleyMap(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.register_buffer("Id", torch.eye(n))
    
    def forward(self, X):
        return torch.linalg.solve(self.Id - X, self.Id + X)
    
layer = nn.Linear(3, 3)
parametrize.register_parametrization(layer, "weight", Symmetric())
parametrize.register_parametrization(layer, "weight", CayleyMap(3))
layer

ParametrizedLinear(
  in_features=3, out_features=3, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): Symmetric()
      (1): CayleyMap()
    )
  )
)

In [34]:
X = layer.weight
print(torch.dist(X.T @ X, torch.eye(3)))

tensor(1077.7062, grad_fn=<DistBackward0>)


In [36]:
# def right_inverse(self, X: Tensor) -> Tensor

class Skew(nn.Module):
    def forward(self, X):
        A = X.triu(1)
        return A - A.transpose(-1, -2)

    def right_inverse(self, A):
        return A.triu(1)

In [37]:
layer = nn.Linear(3, 3)
parametrize.register_parametrization(layer, "weight", Skew())
X = torch.rand(3, 3)
X = X - X.T
layer.weight = X
print(torch.dist(layer.weight, X))

tensor(0., grad_fn=<DistBackward0>)


In [42]:
layer = nn.Linear(3, 3)
print("Before:")
print(layer)
print(layer.weight)
parametrize.register_parametrization(layer, "weight", Skew())
print("\nParametrized:")
print(layer)
print(layer.weight)
parametrize.remove_parametrizations(layer, "weight", leave_parametrized=False)
print("\nAfter:")
print(layer)
print(layer.weight)

Before:
Linear(in_features=3, out_features=3, bias=True)
Parameter containing:
tensor([[-0.3721,  0.1208,  0.1343],
        [-0.0668,  0.2415,  0.3783],
        [ 0.2588,  0.1037,  0.5136]], requires_grad=True)

Parametrized:
ParametrizedLinear(
  in_features=3, out_features=3, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): Skew()
    )
  )
)
tensor([[ 0.0000,  0.1208,  0.1343],
        [-0.1208,  0.0000,  0.3783],
        [-0.1343, -0.3783,  0.0000]], grad_fn=<SubBackward0>)

After:
Linear(in_features=3, out_features=3, bias=True)
Parameter containing:
tensor([[0.0000, 0.1208, 0.1343],
        [0.0000, 0.0000, 0.3783],
        [0.0000, 0.0000, 0.0000]], requires_grad=True)


In [43]:
layer

Linear(in_features=3, out_features=3, bias=True)