In [17]:
import torch 
from src.utils.iwo import get_Q, complete_orthonormal_basis, get_basis

L = 10
SEED = 0  # fixed seed so a fresh "Restart & Run All" checks the same matrices

# Local generator: reproducible draws without touching the global RNG state.
rng = torch.Generator().manual_seed(SEED)

# Random matrices of shapes (L-1, L), (L-2, L-1), ..., (1, 2), so they can be
# applied one after another to a vector of length L.
# Sampled directly in float64 instead of drawing float32 values and upcasting.
m_list = [
    torch.randn(l - 1, l, generator=rng, dtype=torch.float64)
    for l in reversed(range(2, L + 1))
]

# Sanity check QR decomposition and basis completion.
QL = get_Q(m_list[0])
qL = complete_orthonormal_basis(QL)
residual = m_list[0] @ qL.t()
# Compare magnitudes: a signed `<=` test would accept arbitrarily large
# negative entries and so could never detect a wrong null-space vector.
assert torch.all(residual.abs() <= 1e-5), (
    f"completed basis vector is not in the null-space of m_list[0]: "
    f"max |residual| = {residual.abs().max().item():.3e}"
)

# Run basis generation.
B = get_basis(m_list)

# The projection of the l-th basis vector after propagating it through matrices W_L ... W_l should lie in the null-space of W_{l-1}.
# We verify this below.

# Flip the tensor, so that the basis vector with the smallest importance comes first.
B = torch.flip(B, dims=[1])

for i in range(len(m_list) - 1):
    # Keep the column 2-D (shape (L, 1)) so it can be left-multiplied repeatedly.
    t = B[:, i : i + 1]
    # Propagate the i-th column through the first i + 1 matrices.
    for W in m_list[: i + 1]:
        t = W @ t
    # Assert that the projection is indeed inside the null-space (magnitude check,
    # so negative residuals are not silently accepted).
    assert torch.all(t.abs() <= 1e-3), (
        f"column {i} of B is not annihilated after {i + 1} matrices: "
        f"max |residual| = {t.abs().max().item():.3e}"
    )

In [10]:
# Euclidean norm of every column of B: square elementwise, sum over dim 0, take the root.
torch.sqrt(torch.sum(B ** 2, dim=0))

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000])