In [1]:
import torch
import torch.nn as nn

**Resources**

1. https://pytorch.org/tutorials/intermediate/parametrizations.html
2. https://pytorch.org/docs/stable/generated/torch.nn.utils.parametrize.register_parametrization.html
3. https://chatgpt.com/share/67fcda61-f324-8005-a335-c71066e41bcd

Implementing an Equivalent of a Linear Layer

In [None]:
# Hand-written module that registers its own weight as an nn.Parameter
class LinearLayer(nn.Module):
    """Bias-free square linear map with a learnable weight matrix.

    Args:
        n_features: side length of the square weight matrix.

    NOTE(review): ``forward`` computes ``weight @ x``, so features sit on the
    second-to-last axis of ``x`` (or its only axis when ``x`` is 1-D).
    ``nn.Linear`` instead applies ``x @ weight.T`` with features on the last
    axis -- confirm the column convention is what is wanted here.
    """

    def __init__(self, n_features):
        super().__init__()
        # Wrapping the tensor in nn.Parameter is what registers it with the module
        # (background: https://chatgpt.com/share/67fcd15f-2044-8005-92fd-0755e399210c)
        initial_weight = torch.randn(n_features, n_features)
        self.weight = nn.Parameter(initial_weight)

    def forward(self, x):
        """Return the matrix product of the weight with ``x``."""
        return self.weight @ x

Introduction to Parametrization

In [10]:
class Symmetry(nn.Module):
    """Parametrization that builds a symmetric matrix from an upper triangle.

    The upper triangle of the input (diagonal included) is kept as-is and its
    strictly-upper part is mirrored below the diagonal. Entries below the
    diagonal of the input do not influence the result. Intended for inputs
    whose last two dimensions are square.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``triu(x)`` plus the transpose of its strictly-upper part."""
        upper = torch.triu(x)
        # diagonal=1 drops the main diagonal so it is not counted twice
        mirrored = torch.triu(x, diagonal=1).transpose(-1, -2)
        return upper + mirrored

In [11]:
# Stock 3 -> 3 linear layer; list what it registers before any parametrization
layer = nn.Linear(in_features=3, out_features=3)

for param_name, tensor in layer.named_parameters():
    print(param_name, tensor.shape)

weight torch.Size([3, 3])
bias torch.Size([3])


In [12]:
import torch.nn.utils.parametrize as P

# Route reads of `layer.weight` through Symmetry(). The guard keeps this cell
# idempotent: without it, every re-run of the cell would append one more
# Symmetry to the weight's ParametrizationList.
if not P.is_parametrized(layer, 'weight'):
    P.register_parametrization(layer, 'weight', Symmetry())

layer  # last expression: display the (now parametrized) module

ParametrizedLinear(
  in_features=3, out_features=3, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): Symmetry()
    )
  )
)

In [13]:
# After registration, `layer.weight` is computed on access by applying
# Symmetry() to the stored tensor, so the printed matrix is symmetric and
# carries a grad_fn instead of being a plain leaf parameter.
A = layer.weight
print(A)

tensor([[ 0.5218,  0.3776,  0.4783],
        [ 0.3776, -0.2035, -0.4937],
        [ 0.4783, -0.4937,  0.1725]], grad_fn=<AddBackward0>)


In [None]:
# Iterating the parametrized weight yields its rows (the symmetric matrix)
for row in layer.weight:
    print(row)

tensor([0.5218, 0.3776, 0.4783], grad_fn=<UnbindBackward0>)
tensor([ 0.3776, -0.2035, -0.4937], grad_fn=<UnbindBackward0>)
tensor([ 0.4783, -0.4937,  0.1725], grad_fn=<UnbindBackward0>)


In [None]:
# The original (unconstrained) weight tensor is kept under
# `parametrizations.weight.original`; iterate over its rows
for row in layer.parametrizations.weight.original:
    print(row)

tensor([0.5218, 0.3776, 0.4783], grad_fn=<UnbindBackward0>)
tensor([ 0.2867, -0.2035, -0.4937], grad_fn=<UnbindBackward0>)
tensor([-0.5015, -0.0754,  0.1725], grad_fn=<UnbindBackward0>)
