In [None]:
import numpy as np
import time
from numpy.linalg import qr
from sklearn.metrics import mean_squared_error

In [None]:
def track_W(old_x, old_w, old_d, lamda):
    """Perform one SPIRIT tracking step for a single direction.

    Parameters
    ----------
    old_x : numpy.ndarray
        Current (residual) input vector; SPIRIT passes a 1-D vector.
    old_w : numpy.ndarray
        Current estimate of the direction, same length as ``old_x``
        (SPIRIT passes a column ``W[:, i]``).
    old_d : float
        Energy accumulator for this direction.
    lamda : float
        Forgetting factor applied to ``old_d``.

    Returns
    -------
    tuple
        ``(w, d, x)``: the updated direction scaled to unit norm, the updated
        energy accumulator, and the residual of ``old_x`` after removing its
        projection onto the updated (not yet normalised) direction.
        The inputs are not modified.
    """
    proj = np.dot(old_w.T, old_x)          # projection of the input on the current direction
    energy = lamda * old_d + proj ** 2
    err = old_x - old_w * proj             # error of the rank-1 reconstruction
    w_new = old_w + err * proj / energy
    # The residual uses the updated direction *before* normalisation.
    residual = old_x - w_new * proj
    return w_new / np.linalg.norm(w_new), energy, residual

In [None]:
def SPIRIT(xt, t, E, E_est, W, d, k, lamda=1, fE=0.95, FE=0.98,
           profiling=False, return_energy=False):
    """Process one observation with the SPIRIT streaming-PCA update.

    The first ``k`` columns of ``W`` are updated one after another with
    ``track_W`` (each step sees the residual left by the previous one). They
    are then re-orthonormalised with a QR decomposition and used to compute
    the hidden variables and the reconstruction of ``xt``. Finally the
    running energies decide whether ``k`` grows or shrinks by one.

    Parameters
    ----------
    xt : numpy.ndarray
        Observation at step ``t``, treated as a 1-D vector of ``n`` features.
    t : int
        1-based time step; weights the running energy averages.
    E, E_est : float
        Running energy of the inputs and of the hidden variables up to
        step ``t - 1``.
    W : numpy.ndarray
        2-D matrix whose first ``k`` columns are the tracked directions.
        Modified in place and also returned.
    d : numpy.ndarray
        Per-direction energy accumulators. Modified in place and returned.
    k : int
        Number of hidden variables currently tracked.
    lamda : float
        Forgetting factor for ``d`` and for the energy averages.
    fE, FE : float
        Lower / upper thresholds on ``E_est / E``: below ``fE`` one more
        hidden variable is added (up to ``xt.shape[0]``), above ``FE`` one
        is dropped (down to 1).
    profiling : bool
        If True, also return the time spent in this call, measured with
        ``time.perf_counter``.
    return_energy : bool
        If True, also return the updated ``E`` and ``E_est`` so the caller
        can pass them to the next call. The default False keeps the
        original return signature.

    Returns
    -------
    tuple
        ``(W, d, k, xt_estimate)``, extended by ``(E, E_est)`` when
        ``return_energy`` is True, then by ``duration`` (seconds) when
        ``profiling`` is True. The returned ``k`` is the value for the next
        step; ``xt_estimate`` was computed with the incoming ``k``.

    Raises
    ------
    ValueError
        If ``t < 1`` (the energy update divides by ``t``).
    """
    if t < 1:
        raise ValueError("t must be >= 1 (1-based time step)")

    if profiling:
        start = time.perf_counter()

    # Update the k tracked directions; each one consumes the residual
    # left over by the previous direction.
    x = xt.copy()
    for i in range(k):
        W[:, i], d[i], x = track_W(x, W[:, i], d[i], lamda)

    # Sequential updates do not keep the directions orthogonal, so replace
    # them by the orthonormal Q factor of a (reduced) QR decomposition.
    Q, _ = qr(W[:, 0:k])
    W[:, 0:k] = Q

    # Hidden variables and reconstruction of the observation.
    Y = np.dot(W[:, 0:k].T, xt)
    xt_estimate = np.dot(W[:, 0:k], Y)

    # Running (optionally forgetting) averages of total and captured energy.
    E = (lamda * (t - 1) * E + np.sum(xt ** 2)) / t
    E_est = (lamda * (t - 1) * E_est + np.sum(Y ** 2)) / t

    # Too little energy captured -> track one more direction;
    # too much captured -> drop one.
    if E_est < (fE * E) and k < xt.shape[0]:
        k = k + 1
    elif E_est > (FE * E) and k > 1:
        k = k - 1

    result = (W, d, k, xt_estimate)
    if return_energy:
        result += (E, E_est)
    if profiling:
        result += (time.perf_counter() - start,)
    return result

In [None]:
# Synthetic data set: 10 observations of 3 features, shifted by +2.
np.random.seed(1)  # seeds NumPy's global RNG; the data below comes from the separate `rng`
rng = np.random.RandomState(1999)
A = rng.randn(10, 3) + 2

# SPIRIT state and hyper-parameters.
W = np.identity(A.shape[1])    # initial directions: the standard basis
d = np.full(A.shape[1], 0.01)  # per-direction energy accumulators
k = 1                          # start by tracking a single hidden variable
t = 1                          # 1-based time step
E = E_est = 0                  # running total / captured energy
fE, FE = 0.95, 0.98            # lower / upper thresholds on E_est / E
lamda = 1                      # forgetting factor (1 = no forgetting)

In [None]:
# Stream the rows of A through SPIRIT one at a time, advancing the time step
# and carrying the running energies forward between calls.
estimates = []
for t, xt in enumerate(A, start=1):
    k_used = k  # SPIRIT reconstructs this row with the incoming k
    W, d, k, xt_estimate = SPIRIT(xt, t, E, E_est, W, d, k, lamda, fE, FE)
    # SPIRIT's 4-value return does not include the updated energies, so
    # repeat its exact update here (W was updated in place by the call).
    Y = np.dot(W[:, 0:k_used].T, xt)
    E = (lamda * (t - 1) * E + np.sum(xt ** 2)) / t
    E_est = (lamda * (t - 1) * E_est + np.sum(Y ** 2)) / t
    estimates.append(xt_estimate)

# Stack once at the end instead of growing the array inside the loop.
A_estimate = np.vstack(estimates) if estimates else None

In [None]:
# print('A')
# print(A)
# print('A_estimate')
# print(A_estimate)

In [None]:
# Overall reconstruction error (sklearn averages the per-column MSE uniformly).
mse = mean_squared_error(y_true=A, y_pred=A_estimate)

In [None]:
# Per-row reconstruction error: one MSE value per observation.
# NOTE(review): this rebinds `mse`, replacing the scalar computed in the
# previous cell with an array; consider a distinct name (e.g. `mse_per_row`).
mse = np.mean(np.square(A - A_estimate), axis=1)