In [149]:
import torch
from monotonenorm import GroupSort, get_normed_weights
from torch.nn.utils.parametrizations import orthogonal
import matplotlib.pyplot as plt
import numpy as np

In [161]:
# Empirically measure several matrix norms of a random (semi-)orthogonal
# 5 x 1000 matrix, averaged over 100 seeds:
#   norm1   -- max column L1 norm (operator 1-norm)
#   norminf -- max row L1 norm (operator inf-norm)
#   norm20  -- max column L2 norm
#   norm21  -- max row L2 norm (1.0 when the rows are orthonormal)
dims = (5, 1000)
norminfs, norm1s, norm20s, norm21s = [], [], [], []
tracked = (
  (norm1s, 0, 1),
  (norminfs, 1, 1),
  (norm20s, 0, 2),
  (norm21s, 1, 2),
)
for seed in range(100):
  torch.manual_seed(seed)
  # The randn draw is immediately overwritten, but it still advances the RNG,
  # so it is kept to reproduce the same orthogonal samples per seed.
  mat = torch.nn.init.orthogonal_(torch.randn(*dims))
  for bucket, dim, p in tracked:
    bucket.append(mat.norm(dim=dim, p=p).max().item())

print(", ".join(
  f"mean {label}: {np.mean(values):.4f}"
  for label, values in (
    ("norm1", norm1s),
    ("norminf", norminfs),
    ("norm20", norm20s),
    ("norm21", norm21s),
  )
))

mean norm1: 0.2907, mean norminf: 25.4665, mean norm20: 0.1470, mean norm21: 1.0000


In [151]:
class Lambda(torch.nn.Module):
    """Adapter that turns any callable into an ``nn.Module`` layer.

    Handy for dropping ad-hoc tensor operations (e.g. a fixed rescale) into
    ``torch.nn.Sequential``. The callable is kept on ``self.f`` and applied
    to the input unchanged.

    Args:
        f: callable taking the layer input and returning the layer output.
           If ``f`` is itself an ``nn.Module``, the normal ``nn.Module``
           attribute assignment registers it as a submodule.
    """

    def __init__(self, f):
        super(Lambda, self).__init__()
        self.f = f

    def forward(self, x):
        """Apply the wrapped callable to ``x``."""
        fn = self.f
        return fn(x)

In [152]:
# Configuration for the exploratory Lipschitz-normalised network below.
input_dim = 256
hidden_dim = 16
depth = 4  # number of Linear layers in the model built in the next cell
per_layer_lip = hidden_dim**.5
# Toggle: uncomment for a plain ReLU baseline instead of GroupSort.
#activation = lambda: GroupSort(hidden_dim // 2) replaced by: lambda: torch.nn.ReLU()
#activation = lambda: torch.nn.ReLU()
activation = lambda: GroupSort(hidden_dim // 2)  # hidden_dim // 2 groups (presumably pairs)
from monotonenorm import direct_norm
# Toggle: uncomment to disable the norm constraint (ordinary network).
#direct_norm = lambda x, *args, **kwargs: x  # make it a normal network

# Seed explicitly so `input` (and the layer initialisation in the next cell)
# does not depend on whatever RNG state earlier cells happened to leave behind.
torch.manual_seed(0)
# NOTE: shadows the builtin `input`; kept because later cells reference it.
input = torch.randn(1000, input_dim)
# Per-layer bound compounded over `depth` Linear layers; the model divides by it.
lip = per_layer_lip**depth

In [153]:
# Build the exploratory GroupSort network. `depth` counts Linear layers:
# one input layer, depth - 2 hidden layers and one scalar output layer.
# Each direct_norm(..., alpha=per_layer_lip) wrapper presumably bounds that
# layer's weight norm by per_layer_lip (TODO confirm against monotonenorm), and
# the final Lambda divides by lip = per_layer_lip**depth to undo the product.
layers = [
  torch.nn.Identity(),  # index 0: lets the stats cell report the raw input
  direct_norm(torch.nn.Linear(input_dim, hidden_dim), kind="one-inf", alpha=per_layer_lip),
  activation(),
]
for _ in range(depth - 2):
  layers.append(
    direct_norm(torch.nn.Linear(hidden_dim, hidden_dim), kind="inf", alpha=per_layer_lip)
  )
  layers.append(activation())
layers.append(
  direct_norm(torch.nn.Linear(hidden_dim, 1), kind="inf", alpha=per_layer_lip)
)
# Bind the current value of `lip` now. A bare `lambda x: x / lip` looks up the
# global on every call, so re-running a later config cell that reassigns `lip`
# would silently rescale this already-built model.
layers.append(
  Lambda(lambda x, scale=lip: x / scale)
)
model = torch.nn.Sequential(*layers).requires_grad_(False)

# Re-initialise every Linear weight orthogonally.
# NOTE(review): if direct_norm keeps the raw weight under another name and
# recomputes `.weight` in a hook, this init could be overwritten -- confirm.
for module in model.modules():
  if isinstance(module, torch.nn.Linear):
    torch.nn.init.orthogonal_(module.weight)
    # Optional experiment: module.weight.data *= input_dim ** (.5 / depth)

In [154]:
# For each layer, print the mean per-feature std and the overall mean of the
# activations after that layer. Activations are propagated incrementally, so
# each layer runs once instead of re-running the prefix model[:i+1] twice per
# row (assumes the layers are deterministic, which the recorded output suggests).
# GroupSort's repr prints a literal "{self.n_groups}", which looks like a missing
# f-string prefix upstream in monotonenorm rather than an issue in this cell.
hidden = input
for i, layer in enumerate(model):
  hidden = layer(hidden)
  print(f"{i}, {str(layer):<60} std {hidden.std(0).mean().item():.4f}, mean {hidden.mean().item():.4f}")

0, Identity()                                                   std 0.9993, mean -0.0030
1, Linear(in_features=256, out_features=16, bias=True)          std 1.0077, mean -0.0224
2, GroupSort(num_groups: {self.n_groups})                       std 0.8351, mean -0.0224
3, Linear(in_features=16, out_features=16, bias=True)           std 0.8329, mean -0.2274
4, GroupSort(num_groups: {self.n_groups})                       std 0.7247, mean -0.2274
5, Linear(in_features=16, out_features=16, bias=True)           std 0.7246, mean 0.2984
6, GroupSort(num_groups: {self.n_groups})                       std 0.6585, mean 0.2984
7, Linear(in_features=16, out_features=1, bias=True)            std 0.6490, mean 0.2821
8, Lambda()                                                     std 0.0025, mean 0.0011


In [158]:
# monotonic network with reasonable size starting configuration

input_dim = 256
hidden_dim = 16
depth = 4  # number of Linear layers in Model
per_layer_lip = hidden_dim**.5
activation = lambda: GroupSort(hidden_dim // 2)
# Re-importing restores the real direct_norm in case the "normal network"
# toggle in the earlier config cell was enabled.
from monotonenorm import direct_norm
#direct_norm = lambda x, *args, **kwargs: x  # make it a normal network

# Seed explicitly so `input` and the subsequent Model() initialisation are
# reproducible regardless of what earlier cells did to the RNG.
torch.manual_seed(0)
input = torch.randn(1000, input_dim)
# NOTE: per_layer_lip / lip are not used by Model below (its direct_norm calls
# pass no alpha); kept for parity with the exploratory configuration.
lip = per_layer_lip**depth



class Model(torch.nn.Module):
  """Monotonic network: a norm-constrained MLP plus a residual input sum.

  Sizes are read from the notebook-level config (``input_dim``, ``hidden_dim``,
  ``depth``, ``activation``) when the model is constructed. ``depth`` counts
  Linear layers: one input layer, ``depth - 2`` hidden layers and one scalar
  output layer, with ``activation()`` after every Linear except the last.

  NOTE(review): direct_norm is called without ``alpha`` here (unlike the
  exploratory model), so each layer presumably gets the library's default
  bound. Confirm it matches the coefficient (1) on the residual ``x.sum`` term
  that is meant to make the output monotonic in every input.
  """

  def __init__(self):
    super().__init__()
    layers = [
      direct_norm(torch.nn.Linear(input_dim, hidden_dim), kind="one-inf"),
      activation(),
    ]
    for _ in range(depth - 2):
      layers.append(
        direct_norm(torch.nn.Linear(hidden_dim, hidden_dim), kind="inf")
      )
      layers.append(activation())
    layers.append(
      direct_norm(torch.nn.Linear(hidden_dim, 1), kind="inf")
    )
    self.layers = torch.nn.ModuleList(layers)

  def forward(self, x):
    """Return ``mlp(x) + x.sum(1, keepdim=True)``.

    Assumes ``x`` is 2-D ``(batch, input_dim)``; the result is ``(batch, 1)``.
    """
    # The MLP branch must start from the input; reading `y` before assigning
    # it would raise UnboundLocalError.
    y = x
    for layer in self.layers:
      y = layer(y)
    return y + x.sum(1, keepdim=True)

In [160]:
# Build the monotonic network and show its layer structure.
model = Model()
print(repr(model))

Model(
  (layers): ModuleList(
    (0): Linear(in_features=256, out_features=16, bias=True)
    (1): GroupSort(num_groups: {self.n_groups})
    (2): Linear(in_features=16, out_features=16, bias=True)
    (3): GroupSort(num_groups: {self.n_groups})
    (4): Linear(in_features=16, out_features=16, bias=True)
    (5): GroupSort(num_groups: {self.n_groups})
    (6): Linear(in_features=16, out_features=1, bias=True)
  )
)
