In [1]:
import numpy as np
import time
import math
from numpy.linalg import qr
from sklearn.metrics import mean_squared_error

In [2]:
def track_W(old_x, old_w, old_d, lamda):
    """Update one tracked direction against the current residual.

    Parameters
    ----------
    old_x : array
        Input vector (the residual left over by previously updated directions).
    old_w : array
        Current direction estimate.
    old_d : scalar or array
        Energy accumulated so far along this direction.
    lamda : scalar
        Forgetting factor multiplied onto ``old_d`` before the new energy is added.

    Returns
    -------
    tuple
        ``(w, d, x)``: the updated direction scaled to unit Euclidean norm,
        the updated energy, and ``old_x`` with the component along the
        (not yet normalised) updated direction removed.
    """
    # Coordinate of the input along the current direction.
    coord = np.matmul(old_w.T, old_x)
    energy = lamda * old_d + coord ** 2

    # Move the direction towards the part of the input it failed to explain.
    residual = old_x - old_w * coord
    direction = old_w + residual * coord / energy

    # Deflation uses the direction *before* it is normalised.
    remainder = old_x - direction * coord
    return direction / (np.linalg.norm(direction)), energy, remainder

In [3]:
def grams(A):
    """Gram-Schmidt orthonormalisation of the columns of ``A``.

    Parameters
    ----------
    A : array_like, shape (rows, n)
        Matrix whose columns are to be orthonormalised. It is not modified;
        integer input is promoted to floating point.

    Returns
    -------
    Q : ndarray, shape (rows, n)
        Columns of ``A`` after orthogonalisation and scaling to unit norm.
    R : ndarray, shape (n, n)
        ``Q.T @ A``.

    Notes
    -----
    If a column's norm falls below ``sqrt(eps)`` after orthogonalisation a
    message is printed and the column is still divided by that norm, so the
    corresponding entries of ``Q`` are unreliable (possibly inf/nan).
    """
    # Work on private floating-point copies: plain assignment would only alias
    # the caller's array, which would then be overwritten and make R wrong.
    source = np.asarray(A)
    original = np.array(source, dtype=np.result_type(source.dtype, np.float64))
    work = original.copy()
    n = work.shape[1]

    # Remove from column j its component along every earlier column 0..j-1.
    for j in range(n):
        for k in range(j):
            mult = np.dot(work[:, j], work[:, k]) / np.dot(work[:, k], work[:, k])
            work[:, j] = work[:, j] - mult * work[:, k]

    Q = np.empty_like(work)
    tol = math.sqrt(np.finfo(float).eps)
    for j in range(n):
        col_norm = np.linalg.norm(work[:, j])
        if col_norm < tol:
            print('Columns of A are linearly dependent.')
        Q[:, j] = work[:, j] / col_norm

    R = np.dot(Q.T, original)
    return Q, R

In [4]:
def SPIRIT(xt, t, E, E_est, W, d, k, lamda=1, fE=0.95, FE=0.98, profiling=False,
           return_energy=False):
    """Process one observation with the SPIRIT tracker.

    Parameters
    ----------
    xt : 1-D array
        Observation for the current time step.
    t : int
        Time-step counter used to average the energies. The update divides by
        ``t``, so it must be non-zero; the caller is responsible for advancing it.
    E, E_est : scalar
        Running energy of the input and of its reconstruction *before* this
        step.
    W : 2-D array
        Direction matrix; its first ``k`` columns are updated IN PLACE.
    d : array
        Per-direction energies; its first ``k`` entries are updated IN PLACE.
    k : int
        Number of directions currently tracked.
    lamda : scalar, optional
        Forgetting factor.
    fE, FE : scalar, optional
        Lower / upper bound on ``E_est / E``; below ``fE`` a direction is
        added, above ``FE`` one is dropped.
    profiling : bool, optional
        When true, the elapsed wall-clock time is appended to the result.
    return_energy : bool, optional
        When true, the updated ``E`` and ``E_est`` are appended to the result
        so the caller can feed them into the next call. Without this flag the
        updated energies are not visible to the caller.

    Returns
    -------
    tuple
        ``(W, d, k, xt_estimate)``, followed by ``duration`` if ``profiling``
        and then by ``E, E_est`` if ``return_energy``.
    """
    if profiling:
        start = time.perf_counter()

    # Update the tracked directions one after another; each call hands back
    # the residual that the next direction is updated against.
    x = xt.copy()
    for i in range(k):
        W[:, i], d[i], x = track_W(x, W[:, i], d[i], lamda)

    # Replace the tracked directions by an orthonormal basis of their span.
    W[:, 0:k], _ = qr(W[:, 0:k])

    # Hidden variables and the reconstruction they produce.
    Y = np.dot(W[:, 0:k].T, xt)
    xt_estimate = np.dot(W[:, 0:k], Y)

    # Running (forgetting-weighted) averages of input and captured energy.
    E = (lamda * (t - 1) * E + np.sum(xt ** 2)) / t
    E_est = (lamda * (t - 1) * E_est + np.sum(Y ** 2)) / t

    # Adapt the number of hidden variables to keep E_est / E within [fE, FE].
    if E_est < (fE * E) and k < xt.shape[0]:
        k = k + 1
    elif E_est > (FE * E) and k > 1:
        k = k - 1

    result = (W, d, k, xt_estimate)
    if profiling:
        result += (time.perf_counter() - start,)
    if return_energy:
        result += (E, E_est)
    return result

In [5]:
# Toy data: 10 samples in 3 dimensions, standard-normal noise shifted by 2.
# The samples come from the dedicated RandomState; the global seed is set too.
np.random.seed(1)
rng = np.random.RandomState(1999)
A = 2 + rng.randn(10, 3)

# Initial SPIRIT state and thresholds.
lamda = 1                # forgetting factor
fE, FE = 0.95, 0.98      # lower / upper energy thresholds
t, k = 1, 1              # time-step counter, number of hidden variables
E = E_est = 0            # running energies of input and reconstruction
W = np.eye(A.shape[1])
d = np.full((A.shape[1], 1), 0.01)

In [6]:
# Stream the rows of A through SPIRIT one time step at a time, collecting the
# reconstructions and the total time spent inside SPIRIT.
estimates = []
duration = 0
for xt in A:
    W, d, k, xt_estimate, runtime = SPIRIT(xt, t, E, E_est, W, d, k, lamda, fE, FE, profiling=True)
    duration += runtime
    estimates.append(xt_estimate)

    # Carry the running energies and the step counter into the next call;
    # without this every call would see t=1 with zero energy. The values are
    # not available from the 5-tuple unpacked above, so its energy update is
    # repeated here. The reconstruction is built from the Q factor of a QR
    # decomposition (orthonormal columns), so its squared norm equals the
    # hidden-variable energy SPIRIT uses internally (up to rounding).
    E = (lamda * (t - 1) * E + np.sum(xt ** 2)) / t
    E_est = (lamda * (t - 1) * E_est + np.sum(xt_estimate ** 2)) / t
    t += 1

# Stack once at the end (always 2-D) instead of re-allocating on every step.
A_estimate = np.vstack(estimates)

In [7]:
# Mean squared error between the observed rows and their SPIRIT reconstructions.
mse = mean_squared_error(y_true=A, y_pred=A_estimate)

In [8]:
def spirit_all(A, lamda, energy, k0=None, holdOffTime=None):
    """Run the SPIRIT tracker over every row of ``A``.

    Parameters
    ----------
    A : 2-D array, shape (totalTime, n)
        One observation per row.
    lamda : scalar
        Forgetting factor applied to the per-direction and accumulated energies.
    energy : sequence of two scalars
        ``(low, high)`` bounds on captured / total energy. Below ``low`` a
        direction is added; above ``high`` one is removed.
    k0 : int, optional
        Initial number of tracked directions. ``None`` (or 0) means 3; the
        value is clamped to the range ``[1, n]``.
    holdOffTime : int, optional
        Minimum number of steps between two changes of the number of
        directions. ``None`` means 10; an explicit 0 is honoured.

    Returns
    -------
    W : ndarray, shape (n, m)
        Orthonormalised directions tracked at the end of the stream.
    m : int
        Final number of tracked directions.
    Proj : ndarray, shape (totalTime, n)
        Hidden variables per time step (unused trailing entries stay zero).
    recon : ndarray, shape (totalTime, n)
        Reconstruction of every row of ``A``.
    """
    # Compare against None so that holdOffTime=0 is not mistaken for "unset".
    if holdOffTime is None:
        holdOffTime = 10
    if not k0:
        k0 = 3

    n = A.shape[1]  # dimension
    totalTime = A.shape[0]
    Proj = np.zeros((totalTime, n))
    recon = np.zeros((totalTime, n))
    W = np.identity(n)
    d = 0.01 * np.ones((n, 1))
    # W only has n columns, so never track more than n directions.
    m = max(1, min(k0, n))

    sumYSq = 0
    sumXSq = 0

    # NOTE(review): with 0-based t and this initial value the first change can
    # happen no earlier than t = holdOffTime + 2; confirm that is intended.
    lastChangeAt = 1
    for t in range(totalTime):
        xActual = A[t, :]

        # Update the tracked directions; x carries the residual between them.
        x = xActual
        for j in range(m):
            W[:, j], d[j], x = track_W(x, W[:, j], d[j], lamda)
        W[:, 0:m], _ = qr(W[:, 0:m])

        Y = np.dot(W[:, 0:m].T, xActual)
        Proj[t, 0:m] = Y
        recon[t, :] = np.dot(W[:, 0:m], Y)

        sumYSq = lamda * sumYSq + np.sum(Y ** 2)
        sumXSq = lamda * sumXSq + np.sum(xActual ** 2)

        changeAllowed = lastChangeAt < t - holdOffTime
        if changeAllowed and m < n and sumYSq < energy[0] * sumXSq:
            lastChangeAt = t
            m = m + 1
        elif changeAllowed and m > 1 and sumYSq > energy[1] * sumXSq:
            # Shrinking only needs m > 1; additionally requiring m < n would
            # leave m stuck at n forever once it got there.
            lastChangeAt = t
            m = m - 1

    W[:, 0:m], _ = qr(W[:, 0:m])
    return W[:, 0:m], m, Proj, recon

In [9]:
# np.random.seed(1)
# rng = np.random.RandomState(1999)
# A = rng.randn(10000, 100) + 2
# start = time.perf_counter()
# W, m, Proj, recon = spirit_all(A, lamda=1, energy=(0.95,0.98), k0=1)
# runtime = time.perf_counter() - start
# runtime

In [10]:
# mse = mean_squared_error(A, recon)
# mse