In [541]:
import torch
from monotonenorm import GroupSort, get_normed_weights
from torch.nn.utils.parametrizations import orthogonal

In [319]:
input_dim = 16
hidden_dim = 16
depth = 4
lip = 5
#activation = lambda: torch.nn.ReLU()
activation = lambda: GroupSort(hidden_dim//2)
#from monotonenorm import direct_norm
direct_norm = lambda x, *args, **kwargs: x  # make it a normal network

input = torch.randn(1000, input_dim)

In [320]:
layers = [
  torch.nn.Identity(),
  direct_norm(torch.nn.Linear(input_dim,hidden_dim), kind="one", alpha=lip ** ( 1 / depth)),
  activation()
]
for i in range(depth - 2):
  layers.append(
    direct_norm(torch.nn.Linear(hidden_dim,hidden_dim), kind="one", alpha=lip ** ( 1 / depth))
  )
  layers.append(activation())
layers.append(
  direct_norm(torch.nn.Linear(hidden_dim,1), kind="one", alpha=lip ** ( 1 / depth))
)
model = torch.nn.Sequential(*layers).requires_grad_(False)


for i in model.modules():
  if isinstance(i, torch.nn.Linear):
    torch.nn.init.orthogonal_(i.weight)
    #i.weight.data *= input_dim ** (.5 / depth)

In [321]:
for i in range(len(model)):
  print(f"{i}, {str(model[i]):<60} std {model[:i+1](input).std(0).mean().item():.4f}, mean {model[:i+1](input).mean().item():.4f}")

0, Identity()                                                   std 1.0037, mean 0.0125
1, Linear(in_features=16, out_features=16, bias=True)           std 1.0036, mean -0.0849
2, GroupSort(num_groups: {self.n_groups})                       std 0.8303, mean -0.0849
3, Linear(in_features=16, out_features=16, bias=True)           std 0.8275, mean -0.0705
4, GroupSort(num_groups: {self.n_groups})                       std 0.7199, mean -0.0705
5, Linear(in_features=16, out_features=16, bias=True)           std 0.7196, mean -0.1614
6, GroupSort(num_groups: {self.n_groups})                       std 0.6303, mean -0.1614
7, Linear(in_features=16, out_features=1, bias=True)            std 0.5381, mean 0.0939


want: f(x) ~ P(0,1)
g(x) ~ P(0,1) ?

want:
f(x) = 1/idim(g(x) + 1x) ~ P(0,1), x_i ~ P(0,1) --> sum(x_i) ~ P(0,i_dim)
g(x) ~ P(0,.5), 1x ~ P(0,.5)

In [728]:
do_norm = lambda x: get_normed_weights(x, kind="one", alpha=1, always_norm=False, vectorwise=True)

In [729]:
#torch.manual_seed(1)
i = torch.randn(100, 2)
t = torch.randn(2,2)
torch.nn.init.orthogonal_(t)
o = i @ do_norm(t).T

In [730]:
o.var(0).mean(), i.var(0).mean()

(tensor(0.8168), tensor(1.1281))

In [731]:
t, do_norm(t)

(tensor([[ 0.1942, -0.9810],
         [-0.9810, -0.1942]]),
 tensor([[ 0.1653, -0.8347],
         [-0.8347, -0.1653]]))

In [733]:
p = 1#float('inf')
print(do_norm(t).norm(dim=0, p=p), do_norm(t).norm(dim=1, p=p))
print(t.norm(dim=0, p=2), t.norm(dim=1, p=2))


tensor([1.0000, 1.0000]) tensor([1.0000, 1.0000])
tensor([1.0000, 1.0000]) tensor([1.0000, 1.0000])


In [738]:
#householder reflection
def householder_reflection(v):
  return torch.eye(v.shape[0]) - 2 * v @ v.T / (v.T @ v)

In [740]:
hh = householder_reflection(torch.tensor([1., 0.]))

In [741]:
v = torch.tensor([1., 0.])

In [743]:
v @ hh

tensor([-1., -2.])