In [2]:
import torch
import matplotlib.pyplot as plt
import numpy as np

plt.style.use('dark_background')
background_color = tuple(np.array([66, 68, 69])/255)
plt.rcParams['axes.facecolor'] = background_color
plt.rcParams['figure.facecolor'] = background_color

# Idea

Our linear layer is linearized and made positive-definite using low dimentional factorization with "residual" connection, $U^TU$ is being kept lipschitz with constant smaller than one.

$$ m >> n \qquad U \in \mathbb{R}^{n \times m} \qquad U^TU$$
$$ l_{U}(x) = \sigma((U^TU + I)x) = \sigma( U^T (Ux) + x ) $$


In [372]:
class InvertibleLayer(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(InvertibleLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = input_dim
        self.hiden_dim = hidden_dim
        
        self.U = torch.randn(input_dim, hidden_dim)/hidden_dim
        self.V = torch.randn(input_dim, hidden_dim)/hidden_dim
        self.b = torch.zeros(input_dim)
        
    def forward(self, X):
        X = X + (X @ self.U) @ self.U.T 
        X = X - (X @ self.V) @ self.V.T
        X = X + self.b
        return X
    
    def constraint(self):
        return torch.relu(torch.linalg.matrix_norm(self.U, ord=2).pow(2) - 1) + torch.relu(torch.linalg.matrix_norm(self.V, ord=2).pow(2) - 1)
    
    def logdet(self, X):
        logdetU = (torch.linalg.svd(U).S.pwd(2) + 1).log().sum()
        logdetV = (torch.linalg.svd(V).S.pwd(2) + 1).log().sum()
        return logdetU + logdetV
    
    def call(self, X):
        return self.forward(X), self.logdet(X)
    
class InvertibleActivation(torch.nn.Module):
    def __init__(self):
        super(InvertibleActivation, self).__init__()
        
    def forward(self, X):
        return torch.tanh(X)
    
    def logdet(self, X):
        return (1 - torch.tanh(X).pow(2)).log().sum()
    
    def call(self, X):
        t = torch.tanh(X)
        return t, (1 - t.pow(2)).log(2).sum()
    
class InvertibleSequential(torch.nn.Module):
    def __init__(self, *layers):
        super(InvertibleSequential, self).__init__()
        self.layers = layers
        
    def forward(self, X):
        for l in self.layers:
            X = l(X)
        return X
    
    def logdet(self, X):
        return self.call(X)[1]
    
    def call(self, X):
        ld = 0
        for l in self.layers:
            X, logdet = l.call(X)
            ld += logdet
        return X, ld

In [373]:
torch.linalg.svd(U.T @ U + torch.eye(100)).S.log().sum()

tensor(0.6737)

In [374]:
(torch.linalg.svd(U).S.pow(2) + 1).prod()

tensor(1.9615)

In [375]:
torch.linalg.det(U.T @ U + torch.eye(100)) * torch.linalg.det(-(V.T @ V) + torch.eye(100))

tensor(0.5713)

In [376]:
torch.linalg.det((-(V.T @ V) + torch.eye(100)) @ (U.T @ U + torch.eye(100)) )

tensor(0.5713)

In [377]:
torch.linalg.eigh((-(V.T @ V) + torch.eye(100)) @ (U.T @ U + torch.eye(100)) )

torch.return_types.linalg_eigh(
eigenvalues=tensor([0.5264, 0.9708, 0.9716, 0.9744, 0.9759, 0.9770, 0.9815, 0.9831, 0.9858,
        0.9862, 0.9884, 0.9913, 0.9916, 0.9949, 0.9953, 0.9955, 0.9962, 0.9963,
        0.9970, 0.9976, 0.9977, 0.9979, 0.9981, 0.9983, 0.9985, 0.9987, 0.9988,
        0.9988, 0.9990, 0.9990, 0.9992, 0.9992, 0.9994, 0.9994, 0.9995, 0.9996,
        0.9996, 0.9997, 0.9997, 0.9997, 0.9998, 0.9998, 0.9998, 0.9998, 0.9999,
        0.9999, 0.9999, 0.9999, 0.9999, 1.0000, 1.0000, 1.0000, 1.0001, 1.0001,
        1.0001, 1.0001, 1.0002, 1.0002, 1.0002, 1.0003, 1.0003, 1.0003, 1.0004,
        1.0004, 1.0005, 1.0005, 1.0006, 1.0007, 1.0008, 1.0008, 1.0009, 1.0009,
        1.0012, 1.0013, 1.0014, 1.0016, 1.0017, 1.0018, 1.0020, 1.0021, 1.0023,
        1.0028, 1.0031, 1.0033, 1.0036, 1.0047, 1.0049, 1.0075, 1.0082, 1.0107,
        1.0118, 1.0138, 1.0155, 1.0186, 1.0209, 1.0238, 1.0266, 1.0280, 1.0307,
        1.0886]),
eigenvectors=tensor([[ 0.0957,  0.2201, -0.0051,  ..., -0.

In [357]:
x = torch.rand(100)
a = 20
U = torch.rand(10,100)/a
V = torch.rand(10,100)/a
M = (U.T @ U) - (V.T @ V) + torch.eye(100)
M = (U.T @ U) + torch.eye(100)
M = (U.T @ V) + torch.eye(100)
M = (-(V.T @ V) + torch.eye(100)) @ ((U.T @ U) + torch.eye(100))

In [309]:
torch.linalg.eig(M).eigenvalues[0]

tensor(-4.7690+0.j)

In [310]:
torch.linalg.eigh(M).eigenvalues[-1]

tensor(1.5540)

In [311]:
torch.linalg.svd(M).S

tensor([4.9816, 1.1284, 1.1219, 1.0972, 1.0780, 1.0704, 1.0610, 1.0513, 1.0472,
        1.0407, 1.0093, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 0.9567, 0.9503, 0.9311, 0.9291, 0.9213, 0.9173, 0.9029, 0.8993,
        0.8762])

In [295]:
torch.linalg.matrix_norm(M, ord=2)

tensor(6.2303e+08)

In [286]:
torch.linalg.det(M)

tensor(-5.3199)

In [267]:
torch.linalg.svd(M).S.prod()

tensor(52.6636)

In [268]:
(torch.linalg.svd(V).S * torch.linalg.svd(U).S)

tensor([2.5742, 0.1245, 0.1088, 0.0968, 0.0933, 0.0817, 0.0707, 0.0616, 0.0545,
        0.0467])

In [269]:
torch.linalg.eigh(M).eigenvalues.prod()

tensor(23.4800)

In [270]:
(torch.linalg.svd(U).S * torch.linalg.svd(V).S + 1)

tensor([3.5742, 1.1245, 1.1088, 1.0968, 1.0933, 1.0817, 1.0707, 1.0616, 1.0545,
        1.0467])

In [271]:
torch.linalg.svd(U).S.pow(2)

tensor([2.6375, 0.1253, 0.1056, 0.1016, 0.0970, 0.0792, 0.0758, 0.0633, 0.0553,
        0.0507])

In [272]:
torch.linalg.matrix_norm(U.T @ U, ord=2)

tensor(2.6375)

In [273]:
torch.svd(V).S.pow(2)

tensor([2.5124, 0.1237, 0.1120, 0.0922, 0.0898, 0.0842, 0.0661, 0.0600, 0.0538,
        0.0431])

In [274]:
torch.linalg.matrix_norm(V, ord=2).pow(2)

tensor(2.5124)

In [275]:
torch.svd(V).S.pow(2) - 1

tensor([ 1.5124, -0.8763, -0.8880, -0.9078, -0.9102, -0.9158, -0.9339, -0.9400,
        -0.9462, -0.9569])