In [1]:
import numpy as np
import pandas as pd
from numpy.linalg import eig
from numpy.linalg import svd
from numpy.linalg import qr

In [2]:
def track_W(old_x, old_w, old_d, lamda):
    """Run one TrackW update for a single weight vector.

    Parameters
    ----------
    old_x : np.ndarray
        Vector to fit: the observation itself, or the residual left over by
        the previously tracked weight vectors (SPIRIT chains calls this way).
    old_w : np.ndarray
        Current weight vector, same length as ``old_x``.
    old_d : float
        Accumulated energy of past projections onto ``old_w``.
    lamda : float
        Forgetting factor applied to ``old_d``.

    Returns
    -------
    tuple
        ``(w, d, x)``:
        - the updated weight vector scaled to unit norm;
        - the updated energy ``lamda * old_d + y**2``;
        - the residual ``old_x - y * w_raw``.
        Here ``y`` is the projection of ``old_x`` onto ``old_w``, and
        ``w_raw`` is the updated weight vector *before* normalisation.
    """
    proj = old_w.T.dot(old_x)
    energy = lamda * old_d + proj ** 2
    # Gradient-style correction of the weight vector by the projection error.
    w_raw = old_w + (old_x - old_w * proj) * proj / energy
    # Deflate with the un-normalised vector, then rescale the returned weight.
    residual = old_x - w_raw * proj
    return w_raw / np.linalg.norm(w_raw), energy, residual

In [3]:
def SPIRIT(xt, t, E, E_est, W, d, k, lamda=1, fE=0.95, FE=0.98,
           verbose=True, return_energy=False):
    """Process one observation with the SPIRIT streaming update.

    Steps:
    1. Track the first ``k`` weight vectors (columns of ``W``) with ``track_W``.
    2. Re-orthonormalise them with a QR decomposition.
    3. Project ``xt`` onto them to get the hidden variables ``Y``.
    4. Reconstruct ``xt`` from ``Y``.
    5. Adapt ``k`` by comparing the running energy of ``Y`` with that of ``xt``.

    Parameters
    ----------
    xt : np.ndarray, shape (n,)
        Observation at time step ``t``.
    t : int
        1-based time step. Used as the averaging denominator, so it must be >= 1.
    E, E_est : float
        Running energy of the observations / of the hidden variables after
        step ``t - 1`` (0 before the first step).
    W : np.ndarray, 2-D
        Weight vectors as columns. It needs at least ``n`` columns because
        ``k`` may grow up to ``n``. Columns ``0..k-1`` are modified in place.
    d : np.ndarray
        Per-weight-vector energy accumulators. Entries ``0..k-1`` are
        modified in place.
    k : int
        Current number of hidden variables.
    lamda : float, default 1
        Forgetting factor.
    fE, FE : float, default 0.95 / 0.98
        Lower / upper energy-retention thresholds used to grow / shrink ``k``.
    verbose : bool, default True
        Print intermediate shapes and energies.
    return_energy : bool, default False
        If True, also return the updated ``E`` and ``E_est``. Feed them into
        the next call so the running energies actually accumulate over time.

    Returns
    -------
    W, d, k, xt_estimate
        ``W`` and ``d`` are the (mutated) input arrays. ``k`` is the number of
        hidden variables to use for the *next* step. ``xt_estimate`` has shape
        ``(n,)`` and is built with the ``k`` passed in.
        When ``return_energy`` is True, ``E`` and ``E_est`` are appended.

    Raises
    ------
    ValueError
        If ``t < 1``.
    """
    if t < 1:
        raise ValueError(f"t must be >= 1, got {t}")

    def _log(message):
        # Debug tracing; optional so callers can silence it.
        if verbose:
            print(message)

    # Track each weight vector in turn; track_W returns the residual of x,
    # so later vectors are fitted to what earlier ones did not explain.
    x = xt
    for i in range(k):
        W[:, i], d[i], x = track_W(x, W[:, i], d[i], lamda)
    _log(f'W[:,0:k].shape {W[:,0:k].shape}')

    # The tracked vectors are not exactly orthogonal; re-orthonormalise them.
    Q, R = qr(W[:, :k])
    W[:, :k] = Q
    _log(f'Q.shape {Q.shape}')
    _log(f'R.shape {R.shape}')
    _log(f'W[:,0:k].shape {W[:,0:k].shape}')

    # Hidden variables: projection of xt onto the k weight vectors, shape (k,).
    W_k = W[:, :k]
    Y = np.dot(W_k.T, xt)
    _log(f'Y.shape {Y.shape}')

    # Reconstruction sum_i Y[i] * W[:, i]. A matrix-vector product (not an
    # element-wise broadcast), so the result has length n for any k.
    xt_estimate = np.dot(W_k, Y)
    _log(f'xt_estimate.shape {xt_estimate.shape}')

    # lamda-weighted running averages of observation / hidden-variable energy.
    E = (lamda * (t - 1) * E + np.sum(xt ** 2)) / t
    _log(f'E {E}')
    E_est = (lamda * (t - 1) * E_est + np.sum(Y ** 2)) / t
    _log(f'E_est {E_est}')

    # Energy thresholding: add a hidden variable when too little energy is
    # retained, drop one when more than enough is.
    n = xt.shape[0]
    if E_est < fE * E and k < n:
        k += 1
        _log(f'Updating k=k+1 = {k}')
    elif E_est > FE * E and k > 1:
        k -= 1
        _log(f'Updating k=k-1 = {k}')

    if return_energy:
        return W, d, k, xt_estimate, E, E_est
    return W, d, k, xt_estimate

In [4]:
# Synthetic stream: 5 time steps x 3 series of standard-normal noise shifted
# by +2, drawn from a fixed seed so every run sees the same data.
rng = np.random.RandomState(1999)
A = rng.randn(5, 3) + 2

# Initial SPIRIT state for the first observation: identity weight vectors,
# one hidden variable, zero energies and a small per-vector energy seed.
t = 1
xt = A[0, :]
n_streams = xt.shape[0]
W = np.eye(n_streams)
d = np.full(n_streams, 0.01)
k = 1
E, E_est = 0, 0

W, d, k, xt_estimate = SPIRIT(xt, t, E, E_est, W, d, k, lamda=1, fE=0.95, FE=0.98)
print(f'xt {xt}', f'xt_estimate {xt_estimate}', sep='\n')

W[:,0:k].shape (3, 1)
Q.shape (3, 1)
R.shape (1, 1)
W[:,0:k].shape (3, 1)
Y.shape (1,)
xt_estimate.shape (3, 1)
E 10.590188021422057
E_est 10.590162188207207
xt [1.68251986 2.69206233 0.71562236]
xt_estimate [1.68686635 2.68951613 0.71494551]


In [5]:
# Second observation of the stream (row t-1 of A for 1-based t).
# NOTE(review): E and E_est are never reassigned from SPIRIT's return value,
# so this step averages against the initial zero energies, not the running
# totals from the previous step.
t = 2
xt = A[t - 1, :]
W, d, k, xt_estimate = SPIRIT(xt, t, E, E_est, W, d, k, lamda=1, fE=0.95, FE=0.98)
print(f'xt {xt}', f'xt_estimate {xt_estimate}', sep='\n')

W[:,0:k].shape (3, 1)
Q.shape (3, 1)
R.shape (1, 1)
W[:,0:k].shape (3, 1)
Y.shape (1,)
xt_estimate.shape (3, 1)
E 6.571822826698817
E_est 6.545909305844012
xt [2.39334583 2.21180288 1.58854318]
xt_estimate [2.31313436 2.37420956 1.45064021]


In [6]:
# Third observation of the stream (row t-1 of A for 1-based t).
# NOTE(review): as in the previous step, E and E_est still hold their initial
# zero values because SPIRIT's updated energies are not assigned back here.
t = 3
xt = A[t - 1, :]
W, d, k, xt_estimate = SPIRIT(xt, t, E, E_est, W, d, k, lamda=1, fE=0.95, FE=0.98)
print(f'xt {xt}', f'xt_estimate {xt_estimate}', sep='\n')

W[:,0:k].shape (3, 1)
Q.shape (3, 1)
R.shape (1, 1)
W[:,0:k].shape (3, 1)
Y.shape (1,)
xt_estimate.shape (3, 1)
E 6.417129533602622
E_est 6.410289471273079
xt [2.94547341 2.90643341 1.45884194]
xt_estimate [2.88546924 2.89514024 1.58842647]


In [7]:
# Current weight matrix. SPIRIT only tracks and re-orthonormalises the first
# k columns; the remaining columns keep their identity-matrix initial values.
W

array([[-0.65798653,  0.        ,  0.        ],
       [-0.66019185,  1.        ,  0.        ],
       [-0.36221603,  0.        ,  1.        ]])