# Toy Example

In [1]:
import sys, os
from pyprojroot import here

# spyder up to find the root
root = here(project_files=[".here"])

# append to path
sys.path.append(str(root))

%load_ext autoreload
%autoreload 2

## Import Packages

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from survae.transforms.bijections.functional.householder import householder_matrix, householder_matrix_fast
from survae.transforms.bijections.linear_orthogonal import LinearHouseholder, LinearOrthogonal

import matplotlib.pyplot as plt

%matplotlib inline

In [3]:
def test_orthogonal(Q):
    
    I = torch.eye(Q.shape[0])
    torch.testing.assert_close(I, Q.T @ Q) 
    torch.testing.assert_close(I, Q @ Q.T) 
    torch.testing.assert_close(Q.inverse(), Q.t(), rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(torch.linalg.slogdet(Q)[1], torch.zeros(()), rtol=1e-3, atol=1e-3)
    
    return True

In [4]:
num_dimensions = 1000
num_reflections = 10

# random iniitialization (for fixed)
vs = torch.randn(num_reflections, num_dimensions)

**Source**: [nflows]()

In [6]:
num_dimensions = 1000
num_reflections = 10

# random iniitialization (for fixed)
vs = torch.randn(num_reflections, num_dimensions)

Q = householder_matrix(vs, loop=True)

test_orthogonal(Q)

True

**Source**: [Invert to Invert]()

In [7]:
Q_bmm = householder_matrix(vs, loop=False)
# check shape
assert Q_bmm.shape[0] == vs.shape[1]
# check orthogonality
torch.testing.assert_close(Q, Q_bmm)
test_orthogonal(Q_bmm)

True

**Source**: [Fast Householder Matrix]()

In [8]:
stride = 2
Q_fast = householder_matrix_fast(vs, stride)

test_orthogonal(Q_fast)
torch.testing.assert_close(Q, Q_fast)
torch.testing.assert_close(Q_bmm, Q_fast)

In [9]:
# %timeit Q = householder_matrix(vs, loop=True)
# %timeit Q_bmm = householder_matrix(vs, loop=False)
# %timeit Q_fast = householder_matrix_fast(vs, 2)

In [10]:
# from tqdm.notebook import tqdm
# import itertools

# num_features = [2, 10, 50, 100, 1_000, 
#                 10_000, 100_000
#                ]
# num_reflections = [
#     2, 2, 5, 10, 50, 50, 100
# ]


# methods = [
#     "matrix", 
#     "loops",
#     "fast", 
#     # "base"
#           ]
# results = {imethod: list() for imethod in methods}

# def run_method(num_dims, num_hh, method: str):
    
#     # random iniitialization (for fixed)
#     vs = torch.randn(num_hh, num_dims)
    
#     if method == "loops":
#         res = %timeit -n10 -r10 -o construct_householder_matrix(vs, loop=True)
#         return res
#     elif method == "matrix":
#         res = %timeit -n10 -r10 -o construct_householder_matrix(vs, loop=False)
#         return res
#     elif method == "fast":
#         res = %timeit -n10 -r10 -o fast_householder_matrix(vs, 2)
#         return res
#     else:
#         raise ValueError(f"Unrecognized method: {method}")

# for ifeatures, ireflections in tqdm(zip(num_features, num_reflections)):
    
#     # benchmarks
#     for imethod in tqdm(methods):
        
#         ires = run_method(ifeatures, ireflections, imethod)
        
#         # extract infor
#         results[imethod].append((ires.average, ires.stdev))

In [11]:
num_batches = 32
num_features = 10
num_reflections = 2

X = torch.randn(32, 10)

In [13]:


lin_hh = LinearHouseholder(
    num_features=num_features, num_reflections=num_reflections, 
    fixed=True
)

Z, ldj = lin_hh(X)

# check shape
assert Z.shape == X.shape


X_approx = lin_hh.inverse(Z)

# check inverse
torch.testing.assert_close(X, X_approx)

In [14]:


lin_hh = LinearHouseholder(
    num_features=num_features, num_reflections=num_reflections, 
    fixed=False, fast=False, loop=True
)

Z, ldj = lin_hh(X)

# check shape
assert Z.shape == X.shape


X_approx = lin_hh.inverse(Z)

# check inverse
torch.testing.assert_close(X, X_approx)

In [15]:
lin_hh_mat = LinearHouseholder(
    num_features=num_features, num_reflections=num_reflections, 
    fixed=False, fast=False, loop=False
)

Z, ldj = lin_hh_mat.forward(X)

# check shape
assert Z.shape == X.shape


X_approx = lin_hh_mat.inverse(Z)

# check inverse
torch.testing.assert_close(X, X_approx)

In [16]:
lin_hh_fast = LinearHouseholder(
    num_features=num_features, num_reflections=num_reflections, 
    fixed=False, fast=True, loop=False
)

Z, ldj = lin_hh_fast.forward(X)

# check shape
assert Z.shape == X.shape


X_approx = lin_hh_fast.inverse(Z)

# check inverse
torch.testing.assert_close(X, X_approx)

In [167]:
num_dimensions = 100
num_reflections = 10

# random iniitialization (for fixed)
vs = torch.randn(num_reflections, num_dimensions)

V_t = vs.unsqueeze(2).transpose(1, 2)
V = vs.unsqueeze(2)
I = torch.eye(n_dimensions, dtype=vs.dtype, device=vs.device)

U = I - 2 * torch.bmm(V, V_t) / torch.bmm(V_t, V)
Q = torch.chain_matmul(*U)
Q_ = torch.linalg.multi_dot(tuple(U))

torch.testing.assert_close(Q, Q_)

In [None]:
        V = self.weights
        V_t = self.weights.transpose(1, 2)
        U = self.I - 2 * torch.bmm(V, V_t) / torch.bmm(V_t, V)
        W = torch.chain_matmul(*U)
        return W

In [131]:
num_dims = 100
num_reflections = 10

# random iniitialization (for fixed)
vs = torch.randn(num_dims, num_reflections)
vs = vs.transpose(-1, -2)
print(vs.shape)
# close to ideneity
vs = torch.eye(num_dims, num_reflections)
vs += torch.randn_like(vs) * 0.1
vs = vs.transpose(-1, -2)
print(vs.shape)

torch.Size([10, 100])
torch.Size([10, 100])


In [132]:
Q_fast = fast_householder_matrix(vs, 10)
Q = construct_householder_matrix(vs)

torch.testing.assert_close(Q, Q_fast)
Q_fast.shape

torch.Size([100, 100])

In [None]:
from tqdm.notebook import tqdm

num_features = [2, 10, 50, 100, 1_000, 
                10_000, 100_000
               ]
methods = [
    # "householder", 
    "cayley",
    # "matrix_exp", 
    # "base"
          ]
results = {imethod: list() for imethod in methods}

def run_method(num_dims, method: str):
    
    X = torch.randn((batch_size, num_dims))
    
    if method == "householder":
        lin = LinearOrthogonal(num_dims, norm="householder")
    elif method == "cayley":
        lin = LinearOrthogonal(num_dims, norm="cayley")
    elif method == "matrix_exp":
        lin = LinearOrthogonal(num_dims, norm="matrix_exp")
    elif method == "base":
        lin = Linear(num_dims)
    else:
        raise ValueError(f"Unrecognized method: {method}")
        
    
    res = %timeit -n10 -r10 -o lin.forward(X)
    return res

for ifeatures in tqdm(num_features):
    
    # benchmarks
    for imethod in tqdm(methods):
        
        ires = run_method(ifeatures, imethod)
        
        # extract infor
        results[imethod].append((ires.average, ires.stdev))

In [None]:
           if self.fixed:
                # init randomly
                init = torch.randn(self.width, self.n_reflections)
            else:
                # init close to identity
                init = torch.eye(self.width, self.n_reflections)
                init += torch.randn_like(init) * 0.1
            Vs = init.transpose(-1, -2)

In [23]:
with torch.no_grad():
    num_dims = 100
    num_reflections = 10
    strides = 2
    v = torch.ones((num_reflections, num_dims))
    
    nn.init.orthogonal_(v)
    scale = torch.sum(v ** 2, dim=-1)
    Q = construct_householder_matrix(v)
    Q_fast = fast_householder_matrix(v, strides)
    
    
torch.testing.assert_close(Q, Q_fast)
test_orthogonal(Q);
test_orthogonal(Q_fast);

## PyTorch

### From Docs

In [24]:
num_dims = 10
num_reflections = 2
v = torch.ones((num_dims, num_reflections))
h, tau = torch.geqrf(v)

Q_t = torch.linalg.householder_product(h, tau)
v.shape, h.shape, tau.shape, Q_t.shape

(torch.Size([10, 2]),
 torch.Size([10, 2]),
 torch.Size([2]),
 torch.Size([10, 2]))

In [25]:
with torch.no_grad():
    num_dims = 10
    num_reflections = 2
    strides = 2
    v = torch.ones((num_reflections, num_dims))
    
    nn.init.orthogonal_(v)
    scale = 2 / (v * v).sum(dim=1)
    Q_t = torch.linalg.householder_product(v.T, scale)
    

# assert Q_t.shape == (num_dims, num_dims)
# test_orthogonal(Q_t);

### Create orthogonal Matrix

In [None]:
batch_size = 32
num_dims = 100
X = torch.randn((batch_size, num_dims))

In [38]:
from survae.transforms.bijections.linear_orthogonal import LinearOrthogonal
from survae.transforms.bijections.linear import Linear

batch_size = 32
num_dims = 50
X = torch.randn((batch_size, num_dims))

lin_ortho = LinearOrthogonal(num_dims, norm="matrix_exp")

Z, ldj = lin_ortho.forward(X)
assert Z.shape == X.shape
assert ldj.shape[0] == X.shape[0]

# test_orthogonal(lin_ortho.weight)
# test_orthogonal(lin_ortho.weight_inv)

In [37]:
lin_ortho = LinearOrthogonal(num_dims, norm="householder")
%timeit lin_ortho.forward(X)
lin_ortho = LinearOrthogonal(num_dims, norm="cayley")
%timeit lin_ortho.forward(X)
lin_ortho = LinearOrthogonal(num_dims, norm="matrix_exp")
%timeit lin_ortho.forward(X)
lin = Linear(num_dims)
%timeit lin.forward(X)

11.1 ms ± 161 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
12.1 ms ± 167 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.93 ms ± 94.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.69 ms ± 53.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [103]:
# from tqdm.notebook import tqdm

# num_features = [2, 10, 50, 100, 1_000, 
#                 10_000, 100_000
#                ]
# methods = [
#     # "householder", 
#     "cayley",
#     # "matrix_exp", 
#     # "base"
#           ]
# results = {imethod: list() for imethod in methods}

# def run_method(num_dims, method: str):
    
#     X = torch.randn((batch_size, num_dims))
    
#     if method == "householder":
#         lin = LinearOrthogonal(num_dims, norm="householder")
#     elif method == "cayley":
#         lin = LinearOrthogonal(num_dims, norm="cayley")
#     elif method == "matrix_exp":
#         lin = LinearOrthogonal(num_dims, norm="matrix_exp")
#     elif method == "base":
#         lin = Linear(num_dims)
#     else:
#         raise ValueError(f"Unrecognized method: {method}")
        
    
#     res = %timeit -n10 -r10 -o lin.forward(X)
#     return res

# for ifeatures in tqdm(num_features):
    
#     # benchmarks
#     for imethod in tqdm(methods):
        
#         ires = run_method(ifeatures, imethod)
        
#         # extract infor
#         results[imethod].append((ires.average, ires.stdev))


In [104]:
# fig, ax = plt.subplots()


# for imethod, istats in results.items():
    
#     means, stdevs = zip(*istats)
#     upper = [imean + istd for imean, istd in zip(means, stdevs)]
    
#     ax.plot(num_features, means, label=imethod)
#     # ax.plot(num_features, upper)
    
# ax.set(
#     xlabel="Features",
#     ylabel="Time (secs)",
#     yscale="log"
# )
# plt.legend()
# plt.show()

#### Convolutional

In [105]:
batch_size = 32
n_channels = 3
num_height = 10
num_width = 10
X_img = torch.randn((batch_size, n_channels, num_height, num_width))

In [106]:
conv = nn.Conv2d(n_channels, n_channels, kernel_size=1, bias=False)

In [107]:
Z_img = conv(X_img)
Z_img_ = F.conv2d(X_img, conv.weight)
torch.testing.assert_close(Z_img, Z_img_)

#### orthogonal parameterization

In [108]:
conv = nn.Conv2d(n_channels, n_channels, kernel_size=1, bias=False)
nn.init.orthogonal_(conv.weight)
ortho_conv = torch.nn.utils.parametrizations.orthogonal(
    conv, 
    orthogonal_map="householder" # options: "caley", "householder", "matrix_exp"
)

In [198]:
conv.weight, ortho_conv.kernel

AttributeError: 'ParametrizedConv2d' object has no attribute 'kernel'

In [187]:
Z_img = ortho_conv(X_img)
Z_img_ = F.conv2d(X_img, ortho_conv.weight)
torch.testing.assert_close(Z_img, Z_img_)

In [134]:
V = torch.randn((num_dims, num_dims))
param = nn.Parameter(V)
ortho_param = torch.nn.utils.parametrizations.orthogonal(param)

ValueError: Module 'Parameter containing:
tensor([[ 0.8983, -0.2604, -0.4213,  ...,  0.0279, -0.6643, -0.4095],
        [ 0.3463,  0.7041, -0.4180,  ..., -0.2793,  0.0466, -0.7407],
        [ 0.3860,  0.5907, -0.2380,  ..., -0.5145,  0.2181,  0.2852],
        ...,
        [-1.1514, -0.9348,  0.3687,  ..., -0.1956, -1.3657,  0.9945],
        [-1.1207, -0.3512, -0.7432,  ..., -0.2833,  0.6492, -1.2670],
        [ 0.5332, -0.0198, -0.3383,  ..., -0.3099,  0.4659, -0.8023]],
       requires_grad=True)' has no parameter ot buffer with name 'weight'

In [None]:
nn.init.orthogonal_

In [102]:
tau

tensor([1.4472, 1.5000])

In [78]:
Q.shape

torch.Size([100, 100])

In [65]:
%timeit construct_householder_matrix(v)
%timeit fast_householder_matrix(v, 2)

925 µs ± 9.43 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
413 µs ± 2.76 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [69]:
I = torch.eye(num_dims)
I.shape

torch.Size([100, 100])

In [72]:
(Q.T @ Q).shape, (Q @ Q.T).shape

(torch.Size([10, 10]), torch.Size([10, 10]))