In [1]:
import numpy as np
from numpy.linalg import qr
from sklearn.metrics import mean_squared_error

In [2]:
def track_W(old_x, old_w, old_d, lamda):
    y = np.dot(old_w.T , old_x)
    d = lamda*old_d + y**2
    e = old_x - old_w*y
    w = old_w + e*y/d
    x = old_x - w*y
    w = w/(np.linalg.norm(w))
    return w, d, x

In [3]:
def SPIRIT(xt, t, E, E_est, W, d, k, lamda=1, fE=0.95, FE=0.98):
    x = xt.copy()
    for i in range(k):
        W[:,i], d[i], x = track_W(x, W[:,i], d[i], lamda)
#     print(f'W[:,0:k].shape {W[:,0:k].shape}')
    
    #need to ensure that W[:,0:k] are orthogonal matrix
    Q, R = qr(W[:,0:k])
    W[:,0:k] = Q
#     print(f'Q.shape {Q.shape}')
#     print(f'R.shape {R.shape}')
#     print(f'W[:,0:k].shape {W[:,0:k].shape}')
    
    # compute hidden variables
    Y = np.dot(W[:,0:k].T, xt)
#     print(f'Y.shape {Y.shape}') 
    
    # compute the reconstruction 
    xt_estimate = np.dot(W[:,0:k],Y)
#     print(f'xt_estimate.shape {xt_estimate.shape}')
    
    #update energy
    E = (lamda*(t-1)*E + np.sum(xt**2)) / t
#     print(f'E {E}')
    E_est = (lamda*(t-1)*E_est + np.sum(Y**2)) / t
#     print(f'E_est {E_est}')
    
    #determine whether to update the number of hidden variables
    if E_est < (fE*E) and k < xt.shape[0]:
        k = k + 1
#         print(f'Updating k=k+1 = {k}')
    elif E_est > (FE*E) and k > 1:
        k = k - 1
#         print(f'Updating k=k-1 = {k}')
    return W, d, k, xt_estimate

In [4]:
## generate 3 dimension data set of 10 points
np.random.seed(1)
rng = np.random.RandomState(1999)
A = rng.randn(10, 3) + 2
# initialize SPIRIT parameters
W = np.identity(A.shape[1])
k = 1
E = 0
E_est = 0
t = 1
d = 0.01 * np.ones(A.shape[1])
fE = 0.95
FE = 0.98
lamda = 1

In [5]:
A_estimate = None
for i in range(len(A)):
    xt = A[i,:]
    W, d, k, xt_estimate = SPIRIT(xt, t, E, E_est, W, d, k, lamda, fE, FE)
    if A_estimate is None:
        A_estimate = xt_estimate
    else:
        A_estimate = np.vstack((A_estimate, xt_estimate))

In [6]:
print('A')
print(A)
print('A_estimate')
print(A_estimate)

A
[[1.68251986 2.69206233 0.71562236]
 [2.39334583 2.21180288 1.58854318]
 [2.94547341 2.90643341 1.45884194]
 [3.20463807 1.92962428 2.70887691]
 [0.96984499 1.03195449 3.3538902 ]
 [2.80426695 1.64460456 1.19536682]
 [2.2470435  1.8595206  1.9274096 ]
 [0.68802366 0.16304126 1.75841876]
 [1.16916187 0.6220643  1.90347192]
 [2.09439562 1.49333031 2.62422045]]
A_estimate
[[1.68686635 2.68951613 0.71494551]
 [2.31313436 2.37420956 1.45064021]
 [2.88546924 2.89514024 1.58842647]
 [3.07342868 2.60649241 2.04256242]
 [1.90970857 1.64833589 1.64122764]
 [2.81524762 1.60271317 1.22494098]
 [2.31533224 1.87896589 1.82048878]
 [0.97654663 0.78251744 0.79857392]
 [1.1319285  0.65926073 1.91268586]
 [2.32488633 1.81634467 2.0861019 ]]


In [7]:
mse = mean_squared_error(A, A_estimate)
mse

0.23485365448683496

In [8]:
mse = (np.square(A - A_estimate)).mean(axis=1)
mse

array([8.61107162e-06, 1.72756806e-02, 6.84006233e-03, 3.06447122e-01,
       1.39882754e+00, 9.16698238e-04, 5.49117781e-03, 4.62766114e-01,
       9.51598140e-04, 1.49011941e-01])