In [1]:
import numpy as np

In [2]:
def extend_matrix(q):
    """Embed a square matrix in the lower-right corner of a larger identity.

    The result has one more row and one more column than ``q``. Its
    top-left entry is 1, the rest of the first row and first column are 0,
    and the remaining block holds the entries of ``q``. The output is
    always a float64 array.

    Args:
        q: 2-D square array.

    Returns:
        Array of shape ``(k + 1, k + 1)`` where ``q`` has shape ``(k, k)``.

    Raises:
        AssertionError: if ``q`` is not square.
    """
    rows, cols = q.shape
    assert rows == cols, "Matrix must be square"
    size = rows + 1
    extended = np.zeros((size, size))
    # Unit entry for the extra leading coordinate.
    extended[0, 0] = 1.0
    extended[1:, 1:] = q
    return extended

In [3]:
class representation:
    """Layer-by-layer QR representation of a chain of weight matrices.

    ``W[i]`` is expected to have shape ``(n[i+1], n[i] + 1)``. The extra
    leading column presumably holds a bias term, matching the fixed unit
    coordinate that ``extend_matrix`` prepends -- TODO confirm.

    Attributes:
        W: the weight matrices as passed in.
        n: layer widths derived from the shapes of ``W``.
        L: number of weight matrices.
        Q, R, U: factor lists, available only after ``QRDecomposition()``.
    """

    def __init__(self, W):
        """Store the weights and derive the layer widths from their shapes.

        Args:
            W: non-empty sequence of 2-D arrays.
        """
        self.W = W
        # n[0] is the column count of W[0] minus one; every further entry
        # is the row count of the corresponding matrix.
        self.n = [W[0].shape[1] - 1] + [w.shape[0] for w in W]
        self.L = len(self.n) - 1

    def QRDecomposition(self):
        """Factor the weight chain one layer at a time.

        For every layer except the last, the current reduced matrix is
        factored with a complete QR decomposition. Only the leading
        ``rank`` rows of the triangular factor are kept, where ``rank`` is
        the smaller dimension of the reduced matrix. The next weight matrix
        is then multiplied by the extended orthogonal factor and truncated
        to its first ``rank + 1`` columns, which becomes the next reduced
        matrix.

        The rank is taken from the matrix actually being factored, so a
        layer with fewer rows than columns is handled instead of failing
        with a shape mismatch.

        Returns:
            Tuple ``(Q, R, U)`` of lists, also stored on the instance:
              * ``Q`` holds ``L - 1`` square orthogonal factors.
              * ``R`` holds ``L`` matrices: the truncated triangular
                factors followed by the last reduced weight matrix.
              * ``U`` holds ``L`` matrices: a zero matrix shaped like
                ``W[0]``, then for each later layer the part of its weight
                matrix acting on the discarded columns of the extended
                orthogonal factor.
        """
        n = self.n
        Q = []
        R = []
        U = [np.zeros((n[1], n[0] + 1))]
        W_current = self.W[0]

        for s in range(self.L - 1):
            Q_cur, R_cur = np.linalg.qr(W_current, mode="complete")

            # Rows of R_cur beyond min(rows, cols) are identically zero.
            rank = min(W_current.shape)
            Q.append(Q_cur)
            R.append(R_cur[:rank])

            Q_ext = extend_matrix(Q_cur)
            Q_ext_t = extend_matrix(np.transpose(Q_cur))

            # One extra column for the leading unit coordinate added by
            # extend_matrix.
            keep = rank + 1

            QP = np.copy(Q_ext)
            # Scalar assignment zeroes the slice whatever its width; for a
            # narrow layer `keep` can cover every column.
            QP[:, :keep] = 0.0
            U.append(self.W[s + 1] @ QP @ Q_ext_t)

            W_current = (self.W[s + 1] @ Q_ext)[:, :keep]

        R.append(W_current)

        self.Q = Q
        self.R = R
        self.U = U

        return (Q, R, U)

    def printShapes(self):
        """Print the shapes of the stored Q, R and U factors.

        Requires ``QRDecomposition()`` to have been called first, since
        that is where the three lists are set.
        """
        print("Q    shapes:", [q.shape for q in self.Q])
        print("R    shapes:", [r.shape for r in self.R])
        print("U    shapes:", [u.shape for u in self.U])

In [4]:
def testRepresentation(verbose=False):
    n = [3, 5, 8, 8, 2]
    L = len(n)-1
    W = [np.random.rand(n[i+1],n[i]+1) for i in range(L)]
    rep = representation(W)
    rep.QRDecomposition()
    if verbose:
        for i,w in enumerate(W):
            print("W[" + str(i) + "] has shape: " + str(w.shape))
        rep.printShapes()
    return rep

In [6]:
# Run the random example once, printing all shapes, and keep the result for the cells below.
demo = testRepresentation(verbose=True)

W[0] has shape: (5, 4)
W[1] has shape: (8, 6)
W[2] has shape: (8, 9)
W[3] has shape: (2, 9)
Q    shapes: [(5, 5), (8, 8), (8, 8)]
R    shapes: [(4, 4), (5, 5), (6, 6), (2, 7)]
U    shapes: [(5, 4), (8, 6), (8, 9), (2, 9)]


In [7]:
# Show each orthogonal factor rounded to three decimals, followed by its shape.
for q_factor in demo.Q:
    print(q_factor.round(3), q_factor.shape, end="\n \n")

[[-0.32  -0.176 -0.167  0.909 -0.112]
 [-0.472 -0.297 -0.12  -0.338 -0.749]
 [-0.534 -0.481 -0.124 -0.223  0.647]
 [-0.087  0.45  -0.881 -0.097  0.061]
 [-0.619  0.668  0.408 -0.006  0.062]] (5, 5)
 
[[-0.393  0.078  0.384  0.185 -0.488 -0.179 -0.365 -0.505]
 [-0.226 -0.669 -0.063 -0.38   0.246  0.015 -0.54   0.032]
 [-0.431  0.067 -0.403 -0.111  0.076 -0.757  0.236  0.024]
 [-0.381  0.107 -0.024 -0.372  0.271  0.427  0.395 -0.541]
 [-0.01  -0.47  -0.071 -0.226 -0.707  0.106  0.443  0.124]
 [-0.391 -0.012 -0.601  0.522 -0.107  0.411 -0.15   0.103]
 [-0.208 -0.465  0.423  0.548  0.331 -0.074  0.381  0.043]
 [-0.517  0.303  0.376 -0.219 -0.038  0.163 -0.023  0.651]] (8, 8)
 
[[-0.163  0.246 -0.168  0.145  0.67   0.064  0.395 -0.505]
 [-0.553 -0.491  0.277  0.16  -0.13  -0.043  0.556  0.15 ]
 [-0.156  0.416 -0.553 -0.127 -0.095  0.129  0.359  0.571]
 [-0.519 -0.138 -0.434 -0.46  -0.226 -0.188 -0.248 -0.404]
 [-0.366 -0.013 -0.238  0.72  -0.001  0.269 -0.462  0.069]
 [-0.333  0.266  0.272 

In [8]:
# Show each stored R matrix rounded to three decimals, followed by its shape.
for r_factor in demo.R:
    print(r_factor.round(3), r_factor.shape, end="\n \n")

[[-1.441 -0.734 -0.3   -0.826]
 [ 0.     0.388  0.34   0.212]
 [ 0.     0.    -0.829 -0.608]
 [ 0.     0.     0.     0.55 ]] (4, 4)
 
[[-1.847  2.503 -0.752  2.058 -0.622]
 [ 0.     1.462 -0.301  0.479 -0.118]
 [ 0.     0.     0.68   0.305 -0.28 ]
 [ 0.     0.     0.    -0.272 -0.263]
 [ 0.     0.     0.     0.    -0.732]] (5, 5)
 
[[-1.13   3.446  1.511  0.02   0.489  0.364]
 [ 0.    -2.016 -0.613 -0.357 -0.094  0.163]
 [ 0.     0.     0.908  0.254  0.089  0.136]
 [ 0.     0.     0.     0.384 -0.296  0.272]
 [ 0.     0.     0.     0.    -0.292  0.425]
 [ 0.     0.     0.     0.     0.     0.414]] (6, 6)
 
[[ 0.515 -1.631  0.412 -0.179  0.603  0.181 -0.198]
 [ 0.491 -1.578  0.882 -0.054  0.298 -0.187  0.565]] (2, 7)
 
