In [109]:
%load_ext autoreload
%autoreload 2

import torch 
from src.utils.iwo import get_basis

L = 10
W_list = [torch.rand(l - 1, l).to(torch.float64) for l in reversed(range(2, L + 1))]
W_list = (
    [torch.rand(5, L).to(torch.float64)]
    + W_list[5:-2]
    + [torch.rand(1, 3).to(torch.float64)]
)
b_list = get_basis(W_list)

# Test orthogonality
B = torch.concat(b_list, axis=1)
eye = torch.eye(B.size(0), dtype=B.dtype, device=B.device)
assert torch.allclose(B.t() @ B, eye, atol=1e-08)

# Test that the basis vectors are indeed inside the null space of the next smaller matrix
B_flipped = b_list[::-1]  # Re-order from least important to most important
for i in range(len(W_list)):
    reduction = W_list[i].shape[1] - W_list[i].shape[0]
    t = B_flipped[i]
    for j in range(i + 1):
        t = W_list[j] @ t
    # Assert if the the projection is indeed inside the null-space.
    assert torch.any(torch.le(t, 1e-6))


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [110]:
# Scratch check: a list holding four copies of 0.5.
[0.5 for _ in range(4)]

[0.5, 0.5, 0.5, 0.5]

In [None]:

# Inline construction of an orthonormal basis from the chain of maps in W_list,
# so intermediate values (W_prod, T, Qr) can be inspected.
# NOTE(review): this cell has never been executed (In [None]); running it
# overwrites the `b_list` produced by `get_basis` in the cell above.
new_dtype = torch.float64

old_dtype = W_list[0].dtype
device = W_list[0].device

# Work on a converted copy instead of reassigning the shared W_list.
W_work = [m.to(new_dtype) for m in W_list]

b_list = []
# Running product W_i @ ... @ W_0, starting from the identity on the input space.
W_prod = torch.eye(W_work[0].shape[1], device=device, dtype=new_dtype)

for W in W_work:
    W_prod = W @ W_prod
    reduction = W.shape[1] - W.shape[0]
    assert reduction >= 0, f"map {tuple(W.shape)} expands dimensions; expected a non-expanding map"
    # Rows of the cumulative map plus all basis vectors found so far.
    T = torch.concat([W_prod.t()] + b_list, dim=1)
    Qr, _ = torch.linalg.qr(T, mode="complete")
    # When T has full column rank, the trailing columns of the complete Q span the
    # orthogonal complement of span(T). Slice with an explicit start index:
    # `Qr[:, -reduction:]` would select *every* column when reduction == 0.
    b_list.append(torch.flip(Qr[:, Qr.shape[1] - reduction:], dims=[1]))

# Complete the basis with the directions no layer removed.
T = torch.concat(b_list, dim=1)
reduction = T.shape[0] - T.shape[1]
Qr, _ = torch.linalg.qr(T, mode="complete")
b_list.append(torch.flip(Qr[:, Qr.shape[1] - reduction:], dims=[1]))
b_list.reverse()
b_list = [b.to(old_dtype) for b in b_list]

In [100]:
# Inspect the basis blocks. The test cell reverses b_list to get "least important
# to most important", so b_list itself runs from most to least important.
# NOTE(review): this dumps every tensor in full; `[b.shape for b in b_list]`
# would give a compact summary instead.
b_list

[tensor([[0.3438],
         [0.2543],
         [0.3638],
         [0.3072],
         [0.2287],
         [0.2260],
         [0.4675],
         [0.2897],
         [0.2763],
         [0.3289]], dtype=torch.float64),
 tensor([[-0.3589, -0.5186],
         [-0.1743, -0.3852],
         [-0.3736,  0.2272],
         [-0.1395,  0.3791],
         [ 0.0725, -0.2839],
         [ 0.4672,  0.1903],
         [ 0.1241, -0.0115],
         [-0.2950,  0.4221],
         [ 0.2918,  0.2035],
         [ 0.5202, -0.2254]], dtype=torch.float64),
 tensor([[-0.2979],
         [ 0.3589],
         [ 0.2030],
         [-0.0258],
         [-0.0663],
         [-0.5247],
         [-0.2214],
         [-0.1384],
         [ 0.6031],
         [ 0.1702]], dtype=torch.float64),
 tensor([[-0.2443],
         [ 0.1451],
         [-0.5885],
         [ 0.2980],
         [ 0.2255],
         [-0.3754],
         [ 0.4909],
         [ 0.1405],
         [-0.0365],
         [-0.1739]], dtype=torch.float64),
 tensor([[-0.2124,  0.3430, 

In [80]:
# Size of the second basis block (rows x basis vectors).
b_list[1].size()

torch.Size([10, 2])

In [81]:
# Size of the third basis block (rows x basis vectors).
b_list[2].size()

torch.Size([10, 1])

In [82]:
# Size of the fourth basis block (rows x basis vectors).
b_list[3].size()


torch.Size([10, 1])

tensor([[ 1.1102e-16,  2.1684e-17, -2.0817e-17,  8.6736e-17,  3.7470e-16],
        [-1.0755e-16, -9.6494e-18, -7.9364e-17,  1.0582e-16,  2.1511e-16],
        [ 5.5511e-17, -2.6888e-17,  1.8735e-16,  3.1225e-17,  3.0531e-16],
        [ 2.8103e-16,  2.3202e-17,  1.6306e-16,  0.0000e+00, -6.3317e-17],
        [-5.5511e-17, -1.1102e-16,  5.5511e-17,  1.1102e-16,  1.1102e-16]],
       dtype=torch.float64)
tensor([[ 2.7756e-17],
        [ 4.1633e-17],
        [ 8.3267e-17],
        [-1.3878e-17]], dtype=torch.float64)
tensor([[-4.2327e-16],
        [ 5.5511e-16],
        [-3.6082e-16]], dtype=torch.float64)
tensor([[-4.3802e-16, -4.6491e-16]], dtype=torch.float64)


In [36]:
# Number of dimensions removed by the first map (input size minus output size).
reduction = W_list[0].size(1) - W_list[0].size(0)

In [51]:
# B_flipped is a list of basis blocks, so concatenate before indexing a global
# column. For the current W_list shapes, column 4 lies in the first map's
# null-space block, so the result should be ~0.
W_list[0] @ torch.concat(B_flipped, dim=1)[:, 4]

tensor([-2.7756e-17,  3.4694e-17, -1.6653e-16, -2.2204e-16,  5.5511e-17],
       dtype=torch.float64)

In [52]:
# B_flipped is a list of basis blocks, so concatenate before indexing a global
# column. For the current W_list shapes, column 6 lies in the third map's block,
# so pushing it through the first three maps should give ~0.
W_list[2] @ W_list[1] @ W_list[0] @ torch.concat(B_flipped, dim=1)[:, 6]

tensor([-2.2204e-16, -1.3323e-15, -5.5511e-17], dtype=torch.float64)

In [44]:
# Push the last (most important) basis vector through the whole chain
# W_{n-1} @ ... @ W_0, whatever its length; unlike the null-space directions it
# should map to a clearly non-zero value. (The hardcoded W_list[4] no longer
# exists: W_list currently holds 4 maps.)
torch.linalg.multi_dot(W_list[::-1] + [torch.concat(B_flipped, dim=1)[:, -1]])

tensor([18.0246], dtype=torch.float64)