In [27]:
import numpy as np
from numpy.linalg import inv

from scipy.spatial.distance import pdist, squareform
from scipy.sparse.csgraph import laplacian
from scipy.sparse import coo_matrix, csr_matrix
from scipy.sparse.linalg import eigs, eigsh
from scipy.stats.distributions import chi2

from sklearn.neighbors import NearestNeighbors

from matplotlib import pyplot as plt
# Matplotlib (>= 3.3) expects the LaTeX preamble as one string; the list form
# used previously is deprecated there. Older releases accept a plain string too.
plt.rcParams['text.latex.preamble'] = r"\usepackage{xcolor}"
# Default scatter marker: a small point.
plt.rcParams['scatter.marker'] = '.'
from matplotlib.widgets import Slider
from mpl_toolkits.mplot3d import Axes3D

import pdb

# Font-size presets (in points) shared by every figure drawn below.
SMALL_SIZE, MEDIUM_SIZE, BIGGER_SIZE = 14, 16, 18

# Default text and tick labels use the small preset.
plt.rc('font', size=SMALL_SIZE)
plt.rc('xtick', labelsize=SMALL_SIZE)
plt.rc('ytick', labelsize=SMALL_SIZE)
# Axes titles, axis labels and legends use the medium preset.
plt.rc('axes', titlesize=MEDIUM_SIZE, labelsize=MEDIUM_SIZE)
plt.rc('legend', fontsize=MEDIUM_SIZE)
# Figure-level titles use the largest preset.
plt.rc('figure', titlesize=BIGGER_SIZE)


def eval_param(phi, Psi_gamma, Psi_i, k, mask, beta=None):
    """Evaluate the k-th parameterization on the rows of ``phi`` picked by ``mask``.

    The columns ``Psi_i[k, :]`` of ``phi`` are selected for the masked rows and
    multiplied column-wise by ``Psi_gamma[k, :]``. When ``beta`` is given, the
    scale row is first multiplied by ``beta[k]``.

    Returns an array with one row per selected point and one column per
    entry of ``Psi_i[k, :]``.
    """
    scale_row = Psi_gamma[k,:][np.newaxis,:]
    if beta is not None:
        scale_row = beta[k]*scale_row
    return scale_row * phi[np.ix_(mask,Psi_i[k,:])]
    
def compute_zeta(d_e_mask, Psi_k_mask):
    """Return the ratio of the largest to the smallest pairwise stretch.

    For every pair of points the stretch is the embedded distance
    (``pdist(Psi_k_mask)``) divided by the corresponding entry of the square
    distance matrix ``d_e_mask`` (converted to condensed form). A single
    point has no pairs, so 1 is returned.
    """
    n_points = d_e_mask.shape[0]
    if n_points == 1:
        return 1
    embedded = pdist(Psi_k_mask)
    ambient = squareform(d_e_mask)
    stretch = embedded/ambient
    return np.max(stretch)/np.min(stretch)

def cost_of_moving_node(k, d_e, U, phi, Psi_gamma, Psi_i, c,
                        n_C, Utilde, eta_min, eta_max, beta=None):
    """Find the cheapest cluster to which point ``k`` could be moved.

    ``c[k]`` is the current cluster of point ``k`` and ``n_C[m]`` the size of
    cluster ``m``. Nothing is proposed when the current cluster already has
    at least ``eta_min`` members. Otherwise every cluster that owns a point
    of ``U[k, :]`` is examined once; a cluster qualifies when its size is
    below ``eta_max`` and not smaller than the current cluster's size. The
    cost of a qualifying cluster ``m`` is ``compute_zeta`` evaluated on the
    union of ``U[k, :]`` and ``Utilde[m, :]`` using the m-th parameterization.

    Returns
    -------
    (cost, destination) : the smallest cost and its cluster index, or
        ``(np.inf, -1)`` when no cluster qualifies.
    """
    src = c[k]
    src_size = n_C[src]
    if src_size >= eta_min:
        return np.inf, -1

    n_clusters = n_C.shape[0]

    # The source cluster is never a candidate destination.
    seen = np.zeros(n_clusters, dtype='int')
    seen[src] = 1
    move_cost = np.ones(n_clusters)*np.inf

    nbr_mask = U[k,:]==1

    for m in c[nbr_mask].tolist():
        if seen[m]:
            continue
        seen[m] = 1

        dest_size = n_C[m]
        if dest_size >= eta_max:
            continue

        if dest_size >= src_size:
            # Distortion of view m on the merged neighbourhood.
            merged = nbr_mask | (Utilde[m,:]==1)
            move_cost[m] = compute_zeta(d_e[np.ix_(merged,merged)],
                                        eval_param(phi, Psi_gamma, Psi_i, m, merged, beta))

    dest_k = np.argmin(move_cost)
    cost_k = move_cost[dest_k]
    if cost_k == np.inf:
        dest_k = -1
    return cost_k, dest_k
    

class Visualize:
    def __init__(self):
        """Create the helper; no instance attributes are set."""
        pass
    
    def data(self, X, labels, figsize=None, s=20):
        """Scatter-plot the points of ``X`` coloured by ``labels``.

        A new figure is created. Two-column data is drawn with equal axis
        scaling, three-column data on a 3-D axes. The panel is titled 'Data'.

        Parameters
        ----------
        X : array with one row per point.
        labels : colour values forwarded to ``scatter`` as ``c``.
        figsize : forwarded to ``plt.figure``.
        s : marker size.
        """
        n_dims = X.shape[1]
        assert n_dims <= 3, 'X.shape[1] must be either 2 or 3.'
        fig = plt.figure(figsize=figsize)
        if n_dims == 3:
            ax3d = fig.add_subplot(projection='3d')
            ax3d.autoscale()
            ax3d.scatter(X[:,0], X[:,1], X[:,2], s=s, c=labels, cmap='jet')
        elif n_dims == 2:
            plt.scatter(X[:,0], X[:,1], s=s, c=labels, cmap='jet')
            plt.axis('image')
        plt.title('Data')
        
    def eigenvalues(self, lmbda, figsize=None):
        """Plot the eigenvalues ``lmbda`` against their index in a new figure.

        Parameters
        ----------
        lmbda : sequence of eigenvalues, plotted in the given order.
        figsize : forwarded to ``plt.figure``.
        """
        plt.figure(figsize=figsize)
        plt.plot(lmbda, 'o-')
        # Raw string: '\l' is not a valid escape sequence in a normal literal.
        plt.ylabel(r'$\lambda_i$')
        plt.xlabel('i')
        plt.title('Eigenvalues')
        
    def eigenvector(self, X, phi, i, figsize=None, s=20):
        """Scatter-plot ``X`` coloured by the i-th column of ``phi``.

        A new figure with a colorbar is created; the title shows the
        eigenvector index ``i``.

        Parameters
        ----------
        X : array with one row per point (2 or 3 columns are drawn).
        phi : 2-D array whose column ``i`` supplies the colours.
        i : integer column index.
        figsize : forwarded to ``plt.figure``.
        s : marker size.
        """
        assert X.shape[1] <= 3, 'X.shape[1] must be either 2 or 3.'
        fig = plt.figure(figsize=figsize)
        if X.shape[1] == 2:
            plt.scatter(X[:,0], X[:,1], s=s, c=phi[:,i], cmap='jet')
            plt.axis('image')
            plt.colorbar()
        elif X.shape[1] == 3:
            ax = fig.add_subplot(projection='3d')
            ax.autoscale()
            p = ax.scatter(X[:,0], X[:,1], X[:,2], s=s, c=phi[:,i], cmap='jet')
            fig.colorbar(p)
        # Raw string: '\p' is not a valid escape sequence in a normal literal.
        plt.title(r'$\phi_{%d}$' % i)
    
    def distortion(self, X, zeta, figsize=None, s=20):
        """Scatter-plot ``X`` coloured by the per-point values ``zeta``.

        A new figure with a colorbar is created.

        Parameters
        ----------
        X : array with one row per point (2 or 3 columns are drawn).
        zeta : colour values forwarded to ``scatter`` as ``c``.
        figsize : forwarded to ``plt.figure``.
        s : marker size.
        """
        assert X.shape[1] <= 3, 'X.shape[1] must be either 2 or 3.'
        fig = plt.figure(figsize=figsize)
        if X.shape[1] == 2:
            plt.scatter(X[:,0], X[:,1], s=s, c=zeta, cmap='jet')
            plt.axis('image')
            plt.colorbar()
        elif X.shape[1] == 3:
            ax = fig.add_subplot(projection='3d')
            ax.autoscale()
            p = ax.scatter(X[:,0], X[:,1], X[:,2], s=s, c=zeta, cmap='jet')
            fig.colorbar(p)
        # The panel shows zeta, not the raw labels, so it is titled accordingly
        # (the previous title 'Data' duplicated the one used by data()).
        plt.title('Distortion')
    
    def chosen_eigevec_inds_for_local_views(self, X, Psi_i, figsize=None, s=20):
        """Plot ``X`` coloured by the eigenvector indices stored in ``Psi_i``.

        Two stacked panels are drawn in a new figure; panel ``j`` colours
        point ``k`` by ``Psi_i[k, j]`` and carries its own colorbar.

        Parameters
        ----------
        X : array with one row per point (2 or 3 columns are drawn).
        Psi_i : 2-D array; columns 0 and 1 are used as colours.
        figsize : forwarded to ``plt.figure``.
        s : marker size.
        """
        assert X.shape[1] <= 3, 'X.shape[1] must be either 2 or 3.'
        fig = plt.figure(figsize=figsize)
        # Titles are wrapped in '$' so matplotlib renders them as math
        # instead of printing the raw TeX source.
        titles = ('$\\phi_{i_1}$', '$\\phi_{i_2}$')
        for j, title in enumerate(titles):
            if X.shape[1] == 2:
                plt.subplot(211 + j)
                plt.scatter(X[:,0], X[:,1], s=s, c=Psi_i[:,j], cmap='jet')
                plt.axis('image')
                plt.colorbar()
                plt.title(title)
            elif X.shape[1] == 3:
                ax = fig.add_subplot(211 + j, projection='3d')
                ax.autoscale()
                p = ax.scatter(X[:,0], X[:,1], X[:,2], s=s, c=Psi_i[:,j], cmap='jet')
                fig.colorbar(p)
                ax.set_title(title)
    
    def chosen_eigevec_inds_for_intermediate_views(self, X, Psitilde_i, c, figsize=None, s=20):
        """Plot ``X`` coloured by the eigenvector indices of each point's view.

        Two stacked panels are drawn in a new figure; panel ``j`` colours
        point ``k`` by ``Psitilde_i[c[k], j]`` and carries its own colorbar.

        Parameters
        ----------
        X : array with one row per point (2 or 3 columns are drawn).
        Psitilde_i : 2-D array indexed as ``Psitilde_i[c, j]``.
        c : integer array used to pick a row of ``Psitilde_i`` for each point.
        figsize : forwarded to ``plt.figure``.
        s : marker size.
        """
        assert X.shape[1] <= 3, 'X.shape[1] must be either 2 or 3.'
        fig = plt.figure(figsize=figsize)
        # Titles are wrapped in '$' so matplotlib renders them as math
        # instead of printing the raw TeX source.
        titles = ('$\\phi_{i_1}$', '$\\phi_{i_2}$')
        for j, title in enumerate(titles):
            if X.shape[1] == 2:
                plt.subplot(211 + j)
                plt.scatter(X[:,0], X[:,1], s=s, c=Psitilde_i[c,j], cmap='jet')
                plt.axis('image')
                plt.colorbar()
                plt.title(title)
            elif X.shape[1] == 3:
                ax = fig.add_subplot(211 + j, projection='3d')
                ax.autoscale()
                p = ax.scatter(X[:,0], X[:,1], X[:,2], s=s, c=Psitilde_i[c,j], cmap='jet')
                fig.colorbar(p)
                # Title the axes explicitly, as the other 3-D panels do.
                ax.set_title(title)
    
    def local_views(self, X, phi, U, gamma, Atilde, Psi_gamma, Psi_i, zeta, figsize=None, s=20):
        """Interactively browse local views on a 2x3 panel figure (figure 1).

        Each pass of the loop waits up to 20 s for a button press. A key
        press closes the figure and returns; a timeout prints 'Timed out'
        and stops the loop; a mouse click is followed by a second click
        (``plt.ginput``) that selects a point ``k``. For 2-D data ``k`` is
        the point of ``X`` closest to the click; for 3-D data the click
        position is ignored and ``k`` is drawn at random.

        Panels refreshed for the selected ``k``:
          231      data coloured by ``zeta`` with the points of ``U[k]`` in black
          232/233  data coloured by ``phi[:, Psi_i[k, 0]]`` / ``phi[:, Psi_i[k, 1]]``
          234      ``eval_param`` output for view ``k`` (all points red, ``U[k]`` black)
          235      percentiles of |Atilde[k]| normalised by the square roots of its diagonal
          236      percentiles of log(ratio of local scales + 1)

        Parameters
        ----------
        X : array with one row per point; 2 or 3 columns.
        phi : 2-D array; rows index points, columns index eigenvectors.
        U : 2-D array; ``U[k, :] == 1`` marks the points of the k-th view.
        gamma : 2-D array; row ``k`` scales sqrt(diag(|Atilde[k]|)).
        Atilde : 3-D array indexed as ``Atilde[k, i, j]``.
        Psi_gamma, Psi_i : per-view scales and eigenvector indices, passed
            to ``eval_param``.
        zeta : per-point values used for colouring and shown in the title.
        figsize : forwarded to ``plt.figure``.
        s : marker size.

        Notes
        -----
        Needs an interactive backend whose figure manager exposes
        ``window.showMaximized()`` (presumably a Qt backend - TODO confirm).
        """
        assert X.shape[1] <= 3, 'X.shape[1] must be either 2 or 3.'
        is_3d_data = X.shape[1] == 3
        n = phi.shape[0]

        fig = plt.figure(1, figsize=figsize)
        figManager = plt.get_current_fig_manager()
        figManager.window.showMaximized()

        # cb[j] holds the colorbar of top-row panel j so that it can be
        # removed and re-created on every refresh.
        cb = [None, None, None]
        if is_3d_data:
            ax = [None, None, None]
            ax[0] = fig.add_subplot(231, projection='3d')
            ax[1] = fig.add_subplot(232, projection='3d')
            ax[2] = fig.add_subplot(233, projection='3d')
            p = ax[0].scatter(X[:,0], X[:,1], X[:,2], s=s, c=zeta, cmap='jet')
            cb[0] = plt.colorbar(p, ax=ax[0])
            cb[1] = plt.colorbar(p, ax=ax[1])
            cb[2] = plt.colorbar(p, ax=ax[2])
        else:
            plt.subplot(231)
            plt.scatter(X[:,0], X[:,1], s=s, c=zeta, cmap='jet')
            cb[0] = plt.colorbar()
            plt.axis('image')
            plt.subplot(232)
            cb[1] = plt.colorbar()
            plt.subplot(233)
            cb[2] = plt.colorbar()

        while True:
            plt.figure(1, figsize=figsize)
            to_exit = plt.waitforbuttonpress(timeout=20)
            if to_exit is None:
                # No event within the timeout: stop instead of falling
                # through to ginput (same handling as intermediate_views).
                print('Timed out')
                break

            if to_exit:
                plt.close()
                return

            # Plot data with distortion colormap and the
            # selected local view in the ambient space
            if is_3d_data:
                # The click position cannot be mapped onto 3-D data, so the
                # click only triggers the choice of a random point.
                plt.ginput(1)
                k = np.random.randint(n)
            else:
                plt.subplot(231)
                X_k = plt.ginput(1)
                if len(X_k) == 0:
                    # ginput returned without a click; wait for the next event.
                    continue
                X_k = np.array(X_k)
                k = np.argmin(np.sum((X-X_k)**2,1))

            U_k = U[k,:]==1

            if is_3d_data:
                ax[0].cla()
                cb[0].remove()
                # Marker size 0 hides the points of U_k, which are redrawn in black.
                p = ax[0].scatter(X[:,0], X[:,1], X[:,2], s=s*(1-U_k), c=zeta, cmap='jet')
                cb[0] = plt.colorbar(p, ax=ax[0])
                ax[0].scatter(X[U_k,0], X[U_k,1], X[U_k,2], s=s, c='k')
                ax[0].set_title('$\\mathcal{M}$ and $U_{%d}$' % k)
            else:
                # Current axes is still subplot 231 (selected before ginput).
                plt.cla()
                cb[0].remove()
                plt.scatter(X[:,0], X[:,1], s=s*(1-U_k), c=zeta, cmap='jet')
                cb[0] = plt.colorbar()
                plt.scatter(X[U_k,0], X[U_k,1], s=s, c='k')
                plt.axis('image')
                plt.title('$\\mathcal{M}$ and $U_{%d}$' % k)

            # Plot the corresponding local view in the embedding space
            y = eval_param(phi, Psi_gamma, Psi_i, k, np.ones(n)==1)
            plt.subplot(234)
            plt.cla()
            plt.scatter(y[:,0], y[:,1], s=s, c='r')
            plt.scatter(y[U_k,0], y[U_k,1], s=s, c='k')
            plt.axis('image')
            plt.title('$\\zeta_{%d%d}=%.3f\\'\
                      ' \\Phi_{%d}(\\mathcal{M})$ in red and $\\Phi_{%d}(U_{%d})$ in black'\
                      % (k, k, zeta[k], k, k, k))

            # Plot the chosen eigenvectors and scaled eigenvectors
            subplots = [232, 233]
            for j in range(len(subplots)):
                i_s = Psi_i[k,j]
                if is_3d_data:
                    ax[j+1].cla()
                    cb[j+1].remove()
                    p = ax[j+1].scatter(X[:,0], X[:,1], X[:,2], s=s*(1-U_k), c=phi[:,i_s], cmap='jet')
                    cb[j+1] = plt.colorbar(p, ax=ax[j+1])
                    ax[j+1].scatter(X[U_k,0], X[U_k,1], X[U_k,2], s=s, c='k')
                    ax[j+1].set_title('$\\phi_{%d}$' % i_s)
                else:
                    plt.subplot(subplots[j])
                    plt.cla()
                    cb[j+1].remove()
                    plt.scatter(X[:,0], X[:,1], s=s*(1-U_k), c=phi[:,i_s], cmap='jet')
                    cb[j+1]=plt.colorbar()
                    plt.scatter(X[U_k,0], X[U_k,1], s=s, c='k')
                    plt.axis('image')
                    plt.title('$\\phi_{%d}$' % i_s)

            # |Atilde_k[i,j]| / (sqrt(|Atilde_k[i,i]|) * sqrt(|Atilde_k[j,j]|))
            Atilde_k = np.abs(Atilde[k,:,:])
            Atilde_kii = np.sqrt(Atilde_k.diagonal()[:,np.newaxis])
            angles = (Atilde_k/Atilde_kii)/(Atilde_kii.T)

            prctiles = np.arange(100)
            plt.subplot(235)
            plt.cla()
            plt.plot(prctiles, np.percentile(angles.flatten(), prctiles), 'bo-')
            plt.plot([0,100], [0,0], 'g-')
            # Red line: value for the pair of eigenvectors chosen for view k.
            plt.plot([0,100], [angles[Psi_i[k,0],Psi_i[k,1]]]*2, 'r-')
            plt.xlabel('percentiles')
            plt.title('$|\\widetilde{A}_{%dij}|/(\\widetilde{A}_{%dii}\\widetilde{A}_{%djj})$' % (k, k, k))

            # log(1 + ratio of local scales); equal scales give log(2) (green line).
            local_scales = gamma[[k],:].T*Atilde_kii
            dlocal_scales = np.log(local_scales/local_scales.T+1)

            plt.subplot(236)
            plt.cla()
            plt.plot(prctiles, np.percentile(dlocal_scales.flatten(), prctiles), 'bo-')
            plt.plot([0,100], [np.log(2)]*2, 'g-')
            plt.plot([0,100], [dlocal_scales[Psi_i[k,0],Psi_i[k,1]]]*2, 'r-')
            plt.xlabel('percentiles')
            plt.title('$\\log(\\gamma_{%di}\\sqrt{\\widetilde{A}_{%dii}} / \
                      \\gamma_{%dj}\\sqrt{\\widetilde{A}_{%djj}}+1)$' % (k, k, k, k))

            plt.show()
            fig.canvas.draw()
            fig.canvas.flush_events()
    
    def intermediate_views(self, X, phi, Utilde, gamma, Atilde, Psitilde_gamma,
              Psitilde_i, zetatilde, c, figsize=None, s=20):
        """Interactively browse intermediate views on a 2x3 panel figure (figure 1).

        Each pass of the loop waits up to 20 s for a button press. A key
        press closes the figure and returns; a timeout prints 'Timed out'
        and stops the loop; a mouse click is followed by a second click
        (``plt.ginput``) that selects a point ``k``. For 2-D data ``k`` is
        the point of ``X`` closest to the click; for 3-D data the click
        position is ignored and ``k`` is drawn at random. The view shown is
        ``m = c[k]``.

        Panels refreshed for the selected view ``m``:
          231      data coloured by ``zetatilde[c]`` with ``Utilde[m]`` in black
          232/233  data coloured by ``phi[:, Psitilde_i[m, 0]]`` / ``phi[:, Psitilde_i[m, 1]]``
          234      ``eval_param`` output for view ``m`` (all points red, ``Utilde[m]`` black)
          235/236  percentile curves computed from ``Atilde[k]`` and ``gamma[k]``

        Parameters
        ----------
        X : array with one row per point; 2 or 3 columns.
        phi : 2-D array; rows index points, columns index eigenvectors.
        Utilde : 2-D array; ``Utilde[m, :] == 1`` marks the points of view ``m``.
        gamma : 2-D array indexed by point.
        Atilde : 3-D array indexed as ``Atilde[k, i, j]``.
        Psitilde_gamma, Psitilde_i : per-view scales and eigenvector indices,
            passed to ``eval_param``.
        zetatilde : per-view values; ``zetatilde[c]`` colours the points.
        c : integer array mapping each point to its view.
        figsize : forwarded to ``plt.figure``.
        s : marker size.

        Notes
        -----
        Needs an interactive backend whose figure manager exposes
        ``window.showMaximized()`` (presumably a Qt backend - TODO confirm).
        NOTE(review): panels 235/236 index ``Atilde`` and ``gamma`` with the
        point index ``k`` while the highlighted pair comes from view ``m`` -
        confirm this mix is intended.
        """
        assert X.shape[1] <= 3, 'X.shape[1] must be either 2 or 3.'
        is_3d_data = X.shape[1] == 3
        n = phi.shape[0]

        zeta = zetatilde[c]

        fig = plt.figure(1, figsize=figsize)
        figManager = plt.get_current_fig_manager()
        figManager.window.showMaximized()

        # cb[j] holds the colorbar of top-row panel j so that it can be
        # removed and re-created on every refresh.
        cb = [None, None, None]
        if is_3d_data:
            ax = [None, None, None]
            ax[0] = fig.add_subplot(231, projection='3d')
            ax[1] = fig.add_subplot(232, projection='3d')
            ax[2] = fig.add_subplot(233, projection='3d')
            p = ax[0].scatter(X[:,0], X[:,1], X[:,2], s=s, c=zeta, cmap='jet')
            cb[0] = plt.colorbar(p, ax=ax[0])
            cb[1] = plt.colorbar(p, ax=ax[1])
            cb[2] = plt.colorbar(p, ax=ax[2])
        else:
            plt.subplot(231)
            plt.scatter(X[:,0], X[:,1], s=s, c=zeta, cmap='jet')
            cb[0] = plt.colorbar()
            plt.axis('image')
            plt.subplot(232)
            cb[1] = plt.colorbar()
            plt.subplot(233)
            cb[2] = plt.colorbar()

        while True:
            plt.figure(1, figsize=figsize)
            to_exit = plt.waitforbuttonpress(timeout=20)
            if to_exit is None:
                print('Timed out')
                break

            if to_exit:
                plt.close()
                return

            # Plot data with distortion colormap and the
            # selected local view in the ambient space
            if is_3d_data:
                # The click position cannot be mapped onto 3-D data, so the
                # click only triggers the choice of a random point.
                plt.ginput(1)
                k = np.random.randint(n)
            else:
                plt.subplot(231)
                X_k = plt.ginput(1)
                if len(X_k) == 0:
                    # ginput returned without a click; wait for the next event.
                    continue
                X_k = np.array(X_k)
                k = np.argmin(np.sum((X-X_k)**2,1))

            m = c[k]
            # Compare with 1 so the row is a boolean mask even when Utilde
            # stores 0/1 integers; an integer row would be treated as a list
            # of point indices by X[Utilde_m, ...] below.
            Utilde_m = Utilde[m,:]==1

            if is_3d_data:
                ax[0].cla()
                cb[0].remove()
                # Marker size 0 hides the points of the view, which are redrawn in black.
                p = ax[0].scatter(X[:,0], X[:,1], X[:,2], s=s*(1-Utilde_m), c=zeta, cmap='jet')
                cb[0] = plt.colorbar(p, ax=ax[0])
                ax[0].scatter(X[Utilde_m,0], X[Utilde_m,1], X[Utilde_m,2], s=s, c='k')
                ax[0].set_title('$\\mathcal{M}$ and $\\widetilde{U}_{%d}$' % m)
            else:
                # Current axes is still subplot 231 (selected before ginput).
                plt.cla()
                cb[0].remove()
                plt.scatter(X[:,0], X[:,1], s=s*(1-Utilde_m), c=zeta, cmap='jet')
                cb[0] = plt.colorbar()
                plt.scatter(X[Utilde_m,0], X[Utilde_m,1], s=s, c='k')
                plt.axis('image')
                plt.title('$\\mathcal{M}$ and $\\widetilde{U}_{%d}$' % m)

            # Plot the corresponding local view in the embedding space
            y = eval_param(phi, Psitilde_gamma, Psitilde_i, m, np.ones(n)==1)
            plt.subplot(234)
            plt.cla()
            plt.scatter(y[:,0], y[:,1], s=s, c='r')
            plt.scatter(y[Utilde_m,0], y[Utilde_m,1], s=s, c='k')
            plt.axis('image')
            plt.title('$\\widetilde{\\zeta}_{%d%d}=%.3f\\'\
                      ' \\widetilde{\\Phi}_{%d}(\\mathcal{M})$ in red'\
                      ' and $\\widetilde{\\Phi}_{%d}(\\widetilde{U}_{%d})$ in black'\
                      % (m, m, zetatilde[m], m, m, m)) # zetatilde[m] == zeta[k]

            # Plot the chosen eigenvectors and scaled eigenvectors
            subplots = [232, 233]
            for j in range(len(subplots)):
                i_s = Psitilde_i[m,j]
                if is_3d_data:
                    ax[j+1].cla()
                    cb[j+1].remove()
                    p = ax[j+1].scatter(X[:,0], X[:,1], X[:,2], s=s*(1-Utilde_m), c=phi[:,i_s], cmap='jet')
                    cb[j+1] = plt.colorbar(p, ax=ax[j+1])
                    ax[j+1].scatter(X[Utilde_m,0], X[Utilde_m,1], X[Utilde_m,2], s=s, c='k')
                    ax[j+1].set_title('$\\widetilde{\\phi}_{%d}$' % i_s)
                else:
                    plt.subplot(subplots[j])
                    plt.cla()
                    cb[j+1].remove()
                    plt.scatter(X[:,0], X[:,1], s=s*(1-Utilde_m), c=phi[:,i_s], cmap='jet')
                    cb[j+1] = plt.colorbar()
                    plt.scatter(X[Utilde_m,0], X[Utilde_m,1], s=s, c='k')
                    plt.axis('image')
                    plt.title('$\\widetilde{\\phi}_{%d}$' % i_s)

            # |Atilde_k[i,j]| / (sqrt(|Atilde_k[i,i]|) * sqrt(|Atilde_k[j,j]|))
            Atilde_k = np.abs(Atilde[k,:,:])
            Atilde_kii = np.sqrt(Atilde_k.diagonal()[:,np.newaxis])
            angles = (Atilde_k/Atilde_kii)/(Atilde_kii.T)

            prctiles = np.arange(100)
            plt.subplot(235)
            plt.cla()
            plt.plot(prctiles, np.percentile(angles.flatten(), prctiles), 'bo-')
            plt.plot([0,100], [0,0], 'g-')
            # Red line: value for the pair of eigenvectors chosen for view m.
            plt.plot([0,100], [angles[Psitilde_i[m,0],Psitilde_i[m,1]]]*2, 'r-')
            plt.xlabel('percentiles')
            plt.title('$|\\widetilde{A}_{%dij}|/(\\widetilde{A}_{%dii}\\widetilde{A}_{%djj})$' % (k, k, k))

            # log(1 + ratio of local scales); equal scales give log(2) (green line).
            local_scales = gamma[[k],:].T*Atilde_kii
            dlocal_scales = np.log(local_scales/local_scales.T+1)

            plt.subplot(236)
            plt.cla()
            plt.plot(prctiles, np.percentile(dlocal_scales.flatten(), prctiles), 'bo-')
            plt.plot([0,100], [np.log(2)]*2, 'g-')
            plt.plot([0,100], [dlocal_scales[Psitilde_i[m,0],Psitilde_i[m,1]]]*2, 'r-')
            plt.xlabel('percentiles')
            plt.title('$\\log(\\gamma_{%di}\\sqrt{\\widetilde{A}_{%dii}} / \
                      \\gamma_{%dj}\\sqrt{\\widetilde{A}_{%djj}}+1)$' % (k, k, k, k))

            plt.show()
            fig.canvas.draw()
            fig.canvas.flush_events()
        
    def compare_local_high_low_distortion(self, X, Atilde, Psi_gamma, Psi_i, zeta, figsize=None, s=20):
        """Interactively compare low- and high-distortion local views (figure 2).

        Panel 322 shows the percentile curve of ``zeta``. Each pass of the
        loop waits up to 20 s for a button press: a timeout prints
        'Timed out' and stops, a key press closes the figure and returns, and
        a mouse click is followed by a second click whose y-coordinate
        becomes the threshold. Points are then split by ``zeta <= thresh``
        and ``zeta >= thresh``; panel 323 shows the split in the ambient
        space and panels 325/326 show box plots of two per-point statistics
        for the two groups.

        Parameters
        ----------
        X : array with one row per point; 2 or 3 columns.
        Atilde : 3-D array indexed as ``Atilde[k, i, j]``.
        Psi_gamma : 2-D array; columns 0 and 1 hold the two scales per point.
        Psi_i : 2-D integer array; columns 0 and 1 hold the two chosen indices.
        zeta : 1-D array of per-point values.
        figsize : forwarded to ``plt.figure``.
        s : marker size.

        Notes
        -----
        Needs an interactive backend whose figure manager exposes
        ``window.showMaximized()`` (presumably a Qt backend - TODO confirm).
        NOTE(review): a point with ``zeta == thresh`` falls into both groups.
        """
        assert X.shape[1] <= 3, 'X.shape[1] must be either 2 or 3.'
        is_3d_data = X.shape[1] == 3
        
        n = Atilde.shape[0]
        prctiles = np.arange(100)
        
        # Per point k: |Atilde[k,i1,i2]| / (sqrt(Atilde[k,i1,i1]) * sqrt(Atilde[k,i2,i2]))
        # with (i1, i2) = Psi_i[k, :].
        Atilde_ki_1i_2 = np.abs(Atilde[np.arange(n),Psi_i[:,0],Psi_i[:,1]])
        Atilde_ki_1i_1 = np.sqrt(Atilde[np.arange(n),Psi_i[:,0],Psi_i[:,0]])
        Atilde_ki_2i_2 = np.sqrt(Atilde[np.arange(n),Psi_i[:,1],Psi_i[:,1]])
        angles = (Atilde_ki_1i_2/Atilde_ki_1i_1)/(Atilde_ki_2i_2.T)
        
        # Per point k: log(1 + ratio of the two scaled square-root diagonals).
        local_scales_i_1 = Psi_gamma[np.arange(n),0].T*Atilde_ki_1i_1
        local_scales_i_2 = Psi_gamma[np.arange(n),1].T*Atilde_ki_2i_2
        dlocal_scales = np.log(local_scales_i_1/local_scales_i_2+1)
        
        fig2 = plt.figure(2, figsize=figsize)
        figManager = plt.get_current_fig_manager()
        figManager.window.showMaximized()
        
        if is_3d_data:
            # ax[0]: data coloured by zeta; ax[1]: low/high split (filled in the loop).
            ax = [None, None]
            ax[0] = fig2.add_subplot(321, projection='3d')
            ax[1] = fig2.add_subplot(323, projection='3d')
            p = ax[0].scatter(X[:,0], X[:,1], X[:,2], s=s, c=zeta, cmap='jet')
            cb = plt.colorbar(p, ax=ax[0])
            ax[0].autoscale()
            ax[0].set_title('$x_k$ colored by $\\zeta_{kk}$')
        else:
            plt.subplot(321)
            plt.scatter(X[:,0], X[:,1], s=s, c=zeta, cmap='jet')
            plt.colorbar()
            plt.axis('image')
            plt.title('$x_k$ colored by $\\zeta_{kk}$')
            
        # Percentile curve of zeta; the threshold is picked on this panel.
        plt.subplot(322)
        plt.cla()
        plt.plot(prctiles, np.percentile(zeta, prctiles), 'bo-')
        plt.xlabel('percentiles')
        plt.title('$\\zeta_{kk}$')
        plt.show()
        fig2.canvas.draw()
        fig2.canvas.flush_events()
        
        while True:
            # Make the percentile panel current so cla()/plot below target it.
            plt.subplot(322)
            
            to_exit = plt.waitforbuttonpress(timeout=20)
            if to_exit is None:
                print('Timed out')
                break
                
            if to_exit:
                plt.close()
                return
            
            # The y-coordinate of the clicked position is the threshold.
            zeta_k = plt.ginput(1)
            print(zeta_k)
            thresh = zeta_k[0][1]
            
            plt.cla()
            plt.plot(prctiles, np.percentile(zeta, prctiles), 'bo-')
            plt.plot([0,100], [thresh]*2, 'r-')
            plt.xlabel('percentiles')
            plt.title('$\\zeta_{kk}$, threshold = %f' % thresh)
            plt.show()
            fig2.canvas.draw()
            fig2.canvas.flush_events()

            low_dist_mask = zeta <= thresh
            high_dist_mask = zeta >= thresh

            if is_3d_data:
                ax[1].cla()
                ax[1].scatter(X[low_dist_mask,0], X[low_dist_mask,1], X[low_dist_mask,2], s=s, c='b')
                ax[1].scatter(X[high_dist_mask,0], X[high_dist_mask,1], X[high_dist_mask,2], s=s, c='r')
                ax[1].autoscale()
                ax[1].set_title('blue = low distortion, Red = high distortion')
            else:
                plt.subplot(323)
                plt.cla()
                plt.scatter(X[low_dist_mask,0], X[low_dist_mask,1], s=s, c='b')
                plt.scatter(X[high_dist_mask,0], X[high_dist_mask,1], s=s, c='r')
                plt.title('blue = low distortion, Red = high distortion')
                plt.axis('image')

            # Box plots of the two statistics for the low and high groups.
            plt.subplot(325)
            plt.cla() 
            plt.boxplot([angles[low_dist_mask],angles[high_dist_mask]],
                        labels=['low $\\zeta_{kk}$','high $\\zeta_{kk}$'], notch=True,
                               vert=False, patch_artist=True)
            plt.title('$|\\widetilde{A}_{ki_1i_2}|/(\\widetilde{A}_{ki_1i_1}\\widetilde{A}_{ki_2i_2})$')
            plt.show()

            plt.subplot(326)
            plt.cla()
            plt.boxplot([dlocal_scales[low_dist_mask],dlocal_scales[high_dist_mask]],
                        labels=['low $\\zeta_{kk}$','high $\\zeta_{kk}$'], notch=True,
                               vert=False, patch_artist=True)
            plt.title('$\\log(\\gamma_{ki_1}\\sqrt{\\widetilde{A}_{ki_1i_1}} / \
                      \\gamma_{ki_2}\\sqrt{\\widetilde{A}_{ki_2i_2}}+1)$')
            plt.show()
            fig2.canvas.draw()
            fig2.canvas.flush_events()
        
    def compare_intermediate_high_low_distortion(self, X, Atilde, Psitilde_gamma, 
                                                 Psitilde_i, zetatilde, c, figsize=None, s=20):
        """Interactively compare low- and high-distortion intermediate views (figure 2).

        Panel 322 shows the percentile curve of ``zetatilde``. Each pass of
        the loop waits up to 20 s for a button press: a timeout prints
        'Timed out' and stops, a key press closes the figure and returns, and
        a mouse click is followed by a second click whose y-coordinate
        becomes the threshold. Panel 323 splits the points using
        ``zetatilde[c]``; panels 325/326 show box plots of two per-view
        statistics split using ``zetatilde`` itself.

        Parameters
        ----------
        X : array with one row per point; 2 or 3 columns.
        Atilde : 3-D array; its first ``M`` slices are indexed by view here
            (NOTE(review): confirm Atilde is per view rather than per point).
        Psitilde_gamma : 2-D array; columns 0 and 1 hold the two scales per view.
        Psitilde_i : 2-D integer array with ``M`` rows; columns 0 and 1 hold
            the two chosen indices per view.
        zetatilde : 1-D array of per-view values.
        c : integer array mapping each point to its view.
        figsize : forwarded to ``plt.figure``.
        s : marker size.

        Notes
        -----
        Needs an interactive backend whose figure manager exposes
        ``window.showMaximized()`` (presumably a Qt backend - TODO confirm).
        NOTE(review): values equal to the threshold fall into both groups.
        """
        assert X.shape[1] <= 3, 'X.shape[1] must be either 2 or 3.'
        is_3d_data = X.shape[1] == 3

        # Per-point value: that of the view each point is assigned to.
        zeta = zetatilde[c]

        n = Atilde.shape[0]
        M = Psitilde_i.shape[0]
        prctiles = np.arange(100)

        # Per view m: |Atilde[m,i1,i2]| / (sqrt(Atilde[m,i1,i1]) * sqrt(Atilde[m,i2,i2]))
        # with (i1, i2) = Psitilde_i[m, :].
        Atilde_ki_1i_2 = np.abs(Atilde[np.arange(M),Psitilde_i[:,0],Psitilde_i[:,1]])
        Atilde_ki_1i_1 = np.sqrt(Atilde[np.arange(M),Psitilde_i[:,0],Psitilde_i[:,0]])
        Atilde_ki_2i_2 = np.sqrt(Atilde[np.arange(M),Psitilde_i[:,1],Psitilde_i[:,1]])
        angles = (Atilde_ki_1i_2/Atilde_ki_1i_1)/(Atilde_ki_2i_2.T)

        # Per view m: log(1 + ratio of the two scaled square-root diagonals),
        # shifted by log(2) so that equal scales give 0.
        local_scales_i_1 = Psitilde_gamma[np.arange(M),0].T*Atilde_ki_1i_1
        local_scales_i_2 = Psitilde_gamma[np.arange(M),1].T*Atilde_ki_2i_2
        dlocal_scales = np.log(local_scales_i_1/local_scales_i_2+1)-np.log(2)

        fig2 = plt.figure(2, figsize=figsize)
        figManager = plt.get_current_fig_manager()
        figManager.window.showMaximized()

        if is_3d_data:
            # ax[0]: data coloured by zeta; ax[1]: low/high split (filled in the loop).
            ax = [None, None]
            ax[0] = fig2.add_subplot(321, projection='3d')
            ax[1] = fig2.add_subplot(323, projection='3d')
            p = ax[0].scatter(X[:,0], X[:,1], X[:,2], s=s, c=zeta, cmap='jet')
            cb = plt.colorbar(p, ax=ax[0])
            ax[0].autoscale()
            ax[0].set_title('$x_k$ colored by $\\widetilde{\\zeta}_{c_kc_k}$')
        else:
            plt.subplot(321)
            plt.scatter(X[:,0], X[:,1], s=s, c=zeta, cmap='jet')
            plt.colorbar()
            plt.axis('image')
            plt.title('$x_k$ colored by $\\widetilde{\\zeta}_{c_kc_k}$')

        # Percentile curve of zetatilde; the threshold is picked on this panel.
        plt.subplot(322)
        plt.cla()
        plt.plot(prctiles, np.percentile(zetatilde, prctiles), 'bo-')
        plt.xlabel('percentiles')
        plt.title('$\\widetilde{\\zeta}_{mm}$')
        plt.show()
        fig2.canvas.draw()
        fig2.canvas.flush_events()

        while True:
            # Make the percentile panel current so cla()/plot below target it.
            plt.subplot(322)

            to_exit = plt.waitforbuttonpress(timeout=20)
            if to_exit is None:
                print('Timed out')
                break

            if to_exit:
                plt.close()
                return

            # The y-coordinate of the clicked position is the threshold.
            zetatilde_m = plt.ginput(1)
            print(zetatilde_m)
            thresh = zetatilde_m[0][1]

            plt.cla()
            plt.plot(prctiles, np.percentile(zetatilde, prctiles), 'bo-')
            plt.plot([0,100], [thresh]*2, 'r-')
            plt.xlabel('percentiles')
            plt.title('$\\widetilde{\\zeta}_{mm}$, threshold = %f' % thresh)
            plt.show()
            fig2.canvas.draw()
            fig2.canvas.flush_events()

            # Point-level masks for the ambient-space panel.
            low_dist_mask = zeta <= thresh
            high_dist_mask = zeta >= thresh

            if is_3d_data:
                ax[1].cla()
                ax[1].scatter(X[low_dist_mask,0], X[low_dist_mask,1], X[low_dist_mask,2], s=s, c='b')
                ax[1].scatter(X[high_dist_mask,0], X[high_dist_mask,1], X[high_dist_mask,2], s=s, c='r')
                ax[1].autoscale()
                ax[1].set_title('blue = low distortion, Red = high distortion')
            else:
                plt.subplot(323)
                plt.cla()
                plt.scatter(X[low_dist_mask,0], X[low_dist_mask,1], s=s, c='b')
                plt.scatter(X[high_dist_mask,0], X[high_dist_mask,1], s=s, c='r')
                plt.title('blue = low distortion, Red = high distortion')
                plt.axis('image')

            # View-level masks for the box plots (the statistics have one entry per view).
            low_dist_mask = zetatilde <= thresh
            high_dist_mask = zetatilde >= thresh

            plt.subplot(325)
            plt.cla() 
            plt.boxplot([angles[low_dist_mask],angles[high_dist_mask]],
                        labels=['low $\\widetilde{\\zeta}_{mm}$','high $\\widetilde{\\zeta}_{mm}$'], notch=True,
                               vert=False, patch_artist=True)
            plt.title('$|\\widetilde{A}_{ki_1i_2}|/(\\widetilde{A}_{ki_1i_1}\\widetilde{A}_{ki_2i_2})$')
            plt.show()

            plt.subplot(326)
            plt.cla()
            plt.boxplot([dlocal_scales[low_dist_mask],dlocal_scales[high_dist_mask]],
                        labels=['low $\\widetilde{\\zeta}_{mm}$','high $\\widetilde{\\zeta}_{mm}$'], notch=True,
                               vert=False, patch_artist=True)
            plt.title('$\\log(\\gamma_{ki_1}\\sqrt{\\widetilde{A}_{ki_1i_1}} / \
                      \\gamma_{ki_2}\\sqrt{\\widetilde{A}_{ki_2i_2}}+1)$')
            plt.show()
            fig2.canvas.draw()
            fig2.canvas.flush_events()
            

class Datasets:
    """Generators for the synthetic point clouds used in this notebook."""
    def __init__(self):
        pass

    def rectangleGrid(self, ar=16, RES=100):
        """Return a regular grid on a rectangle of unit area.

        The rectangle has side lengths ``sqrt(ar)`` and ``1/sqrt(ar)``, so
        ``ar`` is its aspect ratio. ``RES`` is the number of grid intervals
        per unit length.

        Returns
        -------
        X : (n_points, 2) array of grid coordinates, ordered with the
            y-coordinate varying fastest.
        labelsMat : the same array object as ``X`` (not a copy).
        """
        sideLx = np.sqrt(ar)
        sideLy = 1/sideLx
        # np.linspace requires an integer sample count; a float raises
        # TypeError on current NumPy. Truncation matches what older NumPy
        # releases did with the float value.
        RESx = int(sideLx*RES+1)
        RESy = int(sideLy*RES+1)
        x = np.linspace(0, sideLx, RESx)
        y = np.linspace(0, sideLy, RESy)
        xv, yv = np.meshgrid(x, y)
        # Column-major flattening makes y the fastest-varying coordinate.
        xv = xv.flatten('F')[:,np.newaxis]
        yv = yv.flatten('F')[:,np.newaxis]
        X = np.concatenate([xv,yv], axis=1)
        labelsMat = X
        print('X.shape = ', X.shape)
        return X, labelsMat

    def sphere(self, n=10000):
        """Return ``n`` points spread over a sphere of unit surface area.

        Points follow a golden-angle spiral: the polar angle is
        ``arccos(1 - 2*(i + 0.5)/n)`` and the azimuth advances by
        ``pi*(1 + sqrt(5))`` per point. The radius ``sqrt(1/(4*pi))`` gives
        the sphere a surface area of 1.

        Returns
        -------
        X : (n, 3) array of Cartesian coordinates.
        labelsMat : (n, 2) array with columns (azimuth, polar angle); the
            azimuth is not wrapped to [0, 2*pi).
        """
        R = np.sqrt(1/(4*np.pi))
        indices = np.arange(n)+0.5
        phiv = np.arccos(1 - 2*indices/n)
        phiv = phiv[:,np.newaxis]
        thetav = np.pi*(1 + np.sqrt(5))*indices
        thetav = thetav[:,np.newaxis]
        X = np.concatenate([np.sin(phiv)*np.cos(thetav),
                            np.sin(phiv)*np.sin(thetav),
                            np.cos(phiv)], axis=1)
        X = X*R
        labelsMat = np.concatenate([thetav, phiv], axis=1)
        print('X.shape = ', X.shape)
        return X, labelsMat
        
def graph_laplacian(d_e, k_nn, k_tune, gl_type,
                    return_diag=False, use_out_degree=False):
    """Build a self-tuned k-nearest-neighbour graph Laplacian.

    Parameters
    ----------
    d_e : square matrix of pairwise distances.
    k_nn : number of nearest neighbours (self excluded) kept per point.
    k_tune : the distance to the ``k_tune``-th neighbour is the local scale
        ``sigma`` of a point; must be smaller than ``k_nn``.
    gl_type : 'normed', 'unnorm' or 'random_walk'.
    return_diag, use_out_degree : forwarded to
        ``scipy.sparse.csgraph.laplacian``.

    Returns
    -------
    The sparse Laplacian, or ``(L, D)`` when ``return_diag`` is true. For
    'random_walk' every stored entry of the unnormalised Laplacian is
    divided by the degree of its row.
    """
    assert k_nn > k_tune, "k_nn must be greater than k_tune."
    assert gl_type in ['normed','unnorm','random_walk'],\
            "gl_type should be one of {'normed','unnorm','random_walk'}"

    n_points = d_e.shape[0]

    # k_nn nearest neighbours of every point, self excluded.
    knn = NearestNeighbors(n_neighbors=k_nn,
                           metric='precomputed',
                           algorithm='brute')
    knn.fit(d_e)
    nbr_dist, nbr_idx = knn.kneighbors()

    # Kernel bandwidth of a pair = product of the two local scales.
    sigma = nbr_dist[:,k_tune-1].flatten()
    pair_scale = sigma[nbr_idx]*sigma[:,np.newaxis]

    # Gaussian-type weights; eps guards against a zero bandwidth.
    tiny = np.finfo(np.float64).eps
    weights = np.exp(-nbr_dist**2/(pair_scale+tiny))

    # Sparse n x n kernel with one row of k_nn weights per point.
    rows = np.repeat(np.arange(n_points),k_nn)
    cols = nbr_idx.flatten()
    K = coo_matrix((weights.flatten(),(rows,cols)),shape=(n_points,n_points))

    if gl_type == 'random_walk':
        # The degrees are needed for the row scaling regardless of return_diag.
        L, D = laplacian(K, normed=False,
                         return_diag=True,
                         use_out_degree=use_out_degree)
        L.data /= D[L.row]
        return (L, D) if return_diag else L
    if gl_type in ('normed', 'unnorm'):
        return laplacian(K, normed=(gl_type == 'normed'),
                         return_diag=return_diag,
                         use_out_degree=use_out_degree)
    

def local_views_in_ambient_space(d_e, k):
    """Build the local view of every point as a ball in the ambient space.

    The radius ``epsilon[i]`` is the distance from point ``i`` to its k-th
    nearest neighbour (self excluded). ``U[i, j]`` is true when
    ``d_e[i, j]`` is below that radius plus a 1e-12 tolerance.

    Returns
    -------
    U : boolean matrix with the shape of ``d_e``.
    epsilon : column vector of radii.
    """
    knn = NearestNeighbors(n_neighbors=k,
                           metric='precomputed',
                           algorithm='brute')
    knn.fit(d_e)
    nbr_dist, _ = knn.kneighbors()
    # Keep the column dimension so the comparison broadcasts row-wise.
    kth = k-1
    epsilon = nbr_dist[:,[kth]]
    radius = epsilon + 1e-12
    U = d_e < radius
    return U, epsilon

def compute_Atilde(phi, d_e, U, epsilon, p, d, print_prop = 0.25):
    """Compute the N x N matrix ``Atilde[k]`` for every point ``k``.

    With ``t_k = 0.5*epsilon_k**2/chi2.ppf(p, df=d)``, the kernel
    ``G[k, j] = exp(-d_e[k, j]**2/(4*t_k))`` restricted to ``U`` is
    normalised so that each row sums to one, and

        Atilde[k] = sum over j in U_k of
                    G[k, j]/(2*t_k) * outer(phi[j]-phi[k], phi[j]-phi[k])

    Parameters
    ----------
    phi : (n, N) array; row ``k`` holds the eigenvector values at point ``k``.
    d_e : (n, n) matrix of pairwise distances.
    U : (n, n) mask; ``U[k, :] == 1`` marks the points of the k-th local view.
    epsilon : (n, 1) column of local-view radii.
    p : probability mass passed to ``chi2.ppf``.
    d : degrees of freedom passed to ``chi2.ppf``.
    print_prop : a progress line is printed every ``int(n*print_prop)``
        points; nothing is printed per point when that count is 0.

    Returns
    -------
    Atilde : (n, N, N) array.
    """
    n, N = phi.shape
    # Built-in int: the np.int alias was removed in NumPy 1.24.
    print_freq = int(n*print_prop)

    # Compute G
    t = 0.5*((epsilon**2)/chi2.ppf(p, df=d))
    G = np.exp(-d_e**2/(4*t))*U
    G = G/(np.sum(G,1)[:,np.newaxis])

    # Compute Gtilde (Gtilde_k = (1/t_k)[G_{k1},...,G_{kn}])
    Gtilde = G/(2*t)

    Atilde=np.zeros((n,N,N))
    for k in range(n):
        if print_freq and np.mod(k,print_freq)==0:
            print('Atilde_k: %d points processed...' % k)
        U_k = U[k,:]==1
        # Differences to point k, weighted row-wise by Gtilde before the product.
        dphi_k = phi[U_k,:]-phi[k,:]
        Atilde[k,:,:] = np.dot(dphi_k.T, dphi_k*(Gtilde[k,U_k][:,np.newaxis]))

    print('Atilde_k: all points processed...')
    return Atilde
    
class LDLE:
    '''
        Low Distortion Local Eigenmaps.

        Constructing an instance runs the whole pipeline: graph Laplacian,
        its eigendecomposition, local views, Atilde, gamma, the per-point
        parameterizations (compute_LDLE), their post-processing, the scale
        factors beta and finally the clustering into intermediate views.

        X: Input data with examples in rows and features in columns.
        d_e: Dissimilarity matrix.
             Either of X or d_e must be provided.
        k_nn: number of nearest neighbours to consider for graph Laplacian.
        k_tune: self-tuning parameter for graph Laplacian.
        gl_type: The type of graph Laplacian to use {normed, unnorm, random_walk}.
        N: number of non-trivial eigenpairs of the graph Laplacian to keep
           (N+1 are computed and the first one is discarded).
        k: Distance to the kth nearest neighbor is used to
           construct local view in the ambient space.
        p: probability mass to capture.
        d: intrinsic dimension of the manifold.
        tau: percentile for thresholds.
        delta: fraction for thresholds.
        eta_min: the clustering into intermediate views is run for
                 eta = 2,...,eta_min; in round eta only points whose
                 cluster has fewer than eta members may be moved.
        eta_max: forwarded unchanged to cost_of_moving_node
                 (presumably an upper bound on the cluster size - TODO confirm).
    '''
    def __init__(self,
                 X = None,
                 d_e = None,
                 k_nn = 48,
                 k_tune = 6,
                 gl_type = 'unnorm',
                 N = 100,
                 k = 24,
                 p = 0.99,
                 d = 2,
                 tau = 50,
                 delta = 0.9,
                 eta_min = 5,
                 eta_max = 100):
        assert X is not None or d_e is not None, "Either X or d_e should be provided."
        self.X = X
        self.d_e = d_e
        if d_e is None:
            # Fall back to pairwise distances computed by pdist with its defaults.
            self.d_e = squareform(pdist(X))
        self.k_nn = k_nn
        self.k_tune = k_tune
        self.gl_type = gl_type
        self.N = N
        self.k = k
        self.p = p
        self.d = d
        self.tau = tau
        self.delta = delta
        self.eta_min = eta_min
        self.eta_max = eta_max
        
        # Construct graph Laplacian
        self.L = graph_laplacian(self.d_e, self.k_nn,
                                 self.k_tune, self.gl_type)
        
        # Eigendecomposition of graph Laplacian
        # Note: Eigenvalues are returned sorted.
        # Following is needed for reproducibility of lmbda and phi
        # (this reseeds NumPy's global random state as a side effect).
        np.random.seed(2)
        v0 = np.random.uniform(0,1,self.L.shape[0])
        if self.gl_type != 'random_walk':
            self.lmbda, self.phi = eigsh(self.L, k=self.N+1, v0=v0, which='SM')
        else:
            # The non-symmetric solver is used for the random-walk Laplacian.
            self.lmbda, self.phi = eigs(self.L, k=self.N+1, v0=v0, which='SM')
        
        # Ignore the trivial eigenvalue and eigenvector
        self.lmbda = self.lmbda[1:]
        self.phi = self.phi[:,1:]
        
        # Construct local views in the ambient space
        # and obtain radius of each view
        self.U, epsilon = local_views_in_ambient_space(self.d_e, self.k)
        
        # Compute Atilde
        self.Atilde = compute_Atilde(self.phi, self.d_e, self.U, epsilon, self.p, self.d)
        
        # Compute gamma
        # gamma[k,i] = 1/sqrt(mean of phi[:,i]**2 over the view U_k)
        self.gamma = np.sqrt(1/(np.dot(self.U,self.phi**2)/np.sum(self.U,1)[:,np.newaxis]))
        
        # Compute LDLE: Low Distortion Local Eigenmaps
        self.Psi_gamma0, self.Psi_i0, self.zeta0 = self.compute_LDLE()
        
        # Postprocess LDLE
        self.Psi_gamma, self.Psi_i, self.zeta = self.postprocess_LDLE()
        
        # Compute beta
        self.beta = self.compute_beta()
            
        # Clustering to obtain intermediate views
        self.c, self.Utilde, self.Psitilde_i, self.Psitilde_gamma,\
        self.betatilde, self.zetatilde = self.construct_intermediate_views()
        
    
    def compute_LDLE(self, print_prop = 0.25):
        '''
            For every point k pick d eigenvector indices and their scalings,
            and measure the distortion of the resulting map on the view U_k.

            print_prop: a progress line is printed every int(n*print_prop)
                        points; a value that yields 0 disables them.

            Returns (Psi_gamma, Psi_i, zeta):
              Psi_gamma: (n, d) scalings gamma[k, Psi_i[k]].
              Psi_i: (n, d) integer indices into the columns of phi.
              zeta: (n,) distortion returned by compute_zeta for each view.
        '''
        n = self.phi.shape[0]
        d = self.d
        tau = self.tau
        delta = self.delta
        
        # np.int was removed in NumPy 1.24; the builtin int truncates the same way.
        print_freq = int(n*print_prop)
        
        Psi_gamma = np.zeros((n,d))
        Psi_i = np.zeros((n,d),dtype='int')
        zeta = np.zeros(n)

        for k in range(n):
            if print_freq and np.mod(k, print_freq)==0:
                print('Psi,zeta: %d points processed...' % k)
            
            i = np.zeros(d, dtype='int')
            
            Atilde_k = self.Atilde[k,:,:]
            gamma_k = self.gamma[k,:]
            
            Atilde_kii = Atilde_k.diagonal()
            
            # Candidate set: indices whose diagonal entry reaches the
            # tau-th percentile of the diagonal.
            theta_1 = np.percentile(Atilde_kii, tau)
            Stilde_k = Atilde_kii >= theta_1
            
            # argmax over a boolean array gives the first True position.
            r_1 = np.argmax(Stilde_k)
            temp = gamma_k * np.abs(Atilde_k[:,r_1])
            alpha_1 = np.max(temp * Stilde_k)
            # First candidate whose score is within a factor delta of the best.
            i[0] = np.argmax((temp >= delta*alpha_1) & (Stilde_k))

            for s in range(1,d):
                i_prev = i[0:s]
                temp = inv(Atilde_k[np.ix_(i_prev,i_prev)])
                
                # Diagonal of Atilde_k after removing the part explained by
                # the already selected indices i_prev.
                Hs_kii = Atilde_kii - np.sum(Atilde_k[:,i_prev] * np.dot(temp, Atilde_k[i_prev,:]).T, 1)
                temp_ = Hs_kii[Stilde_k]
                theta_s = np.percentile(temp_, tau)
                
                # Raise the threshold to at least min(max(temp_), 1e-4).
                theta_s=np.max([theta_s,np.min([np.max(temp_),1e-4])])

                r_s = np.argmax((Hs_kii>=theta_s) & Stilde_k)
                # Column r_s of the same residual matrix.
                Hs_kir_s = Atilde_k[:,[r_s]] - np.dot(Atilde_k[:,i_prev], np.dot(temp, Atilde_k[i_prev,r_s][:,np.newaxis]))
                temp = gamma_k * np.abs(Hs_kir_s.flatten())
                alpha_s = np.max(temp * Stilde_k)
                i[s]=np.argmax((temp >= delta*alpha_s) & Stilde_k)
            
            Psi_gamma[k,:] = gamma_k[i]
            Psi_i[k,:] = i
            U_k = self.U[k,:]==1
            zeta[k] = compute_zeta(self.d_e[np.ix_(U_k,U_k)], eval_param(self.phi, Psi_gamma, Psi_i, k, U_k))
            
        print('Psi,zeta: all points processed...')
        return Psi_gamma, Psi_i, zeta
    
    def postprocess_LDLE(self):
        '''
            Repeatedly let every point k adopt the parameters of a point j in
            its view U_k whenever they give a smaller distortion on U_k.
            In every pass after the first only the points whose parameters
            changed in the previous pass are tried as candidates; the loop
            stops after a pass without any change.

            Returns (Psi_gamma, Psi_i, zeta) after post-processing. The
            initial values in self.Psi_gamma0, self.Psi_i0 and self.zeta0
            are left untouched.
        '''
        n = self.d_e.shape[0]
        Psi_i = self.Psi_i0
        Psi_gamma = self.Psi_gamma0
        # Work on a copy: zeta is updated in place below and must not
        # overwrite the initial distortions kept in self.zeta0.
        zeta = self.zeta0.copy()
        
        converged = False
        itr = 1
        # Scalar 1 makes every point in a view a candidate in the first pass.
        param_changed_old = 1
        while not converged:
            replace_with = np.arange(n)
            param_changed_new = np.zeros(n)
            for k in range(n):
                U_k = self.U[k,:]==1
                cand_k = np.where(U_k & (param_changed_old==1))[0]
                if cand_k.size == 0:
                    continue
                # Does not depend on j, so extract it once per point.
                d_e_U_k = self.d_e[np.ix_(U_k,U_k)]
                for j in cand_k:
                    Psi_j_on_U_k = eval_param(self.phi, Psi_gamma, Psi_i, j, U_k)
                    zeta_kj = compute_zeta(d_e_U_k, Psi_j_on_U_k)
                    if zeta_kj < zeta[k]:
                        zeta[k] = zeta_kj
                        param_changed_new[k] = 1
                        replace_with[k] = j
                        
            # Parameters are swapped only after the full pass, so every
            # comparison within a pass uses the parameters of the previous pass.
            Psi_i = Psi_i[replace_with,:]
            Psi_gamma = Psi_gamma[replace_with,:]
            param_changed_old = param_changed_new
            converged = np.sum(param_changed_new)==0
            print("After iter %d, max distortion is %f" % (itr, np.max(zeta)))
            itr = itr + 1
        
        return Psi_gamma, Psi_i, zeta
    
    def compute_beta(self):
        '''
            Compute one scale factor per point: the ratio between the median
            pairwise distance in d_e and the median pairwise distance of the
            point's parameterization, both restricted to the view U_k.
            A view that contains a single point gets the factor 1.

            Returns an array of length n.
        '''
        n = self.phi.shape[0]
        beta = np.zeros(n)
        for k in range(n):
            U_k = self.U[k,:]==1
            d_e_U_k = self.d_e[np.ix_(U_k,U_k)]
            if d_e_U_k.shape[0]==1:
                # Write to the local array; self.beta is only assigned
                # from the return value of this method.
                beta[k] = 1
            else:
                Psi_k_on_U_k = eval_param(self.phi, self.Psi_gamma, self.Psi_i, k, U_k)
                beta[k]=np.median(squareform(d_e_U_k))/np.median(pdist(Psi_k_on_U_k))
        return beta
    
    def construct_intermediate_views(self):
        '''
            Cluster the points into intermediate views. Every point starts in
            its own cluster; for eta = 2,...,self.eta_min the point with the
            smallest finite cost reported by cost_of_moving_node is moved to
            the reported destination cluster until no finite cost remains.
            Empty clusters are then dropped and the rest relabelled 0..M-1.

            Returns (c, Utilde, Psitilde_i, Psitilde_gamma, betatilde, zetatilde):
              c: (n,) cluster label of every point.
              Utilde: (M, n) boolean, union of the local views of the
                      points in each cluster (also stored in self.Utilde).
              Psitilde_i, Psitilde_gamma, betatilde: rows of self.Psi_i,
                      self.Psi_gamma and self.beta of the surviving clusters.
              zetatilde: (M,) distortion of each intermediate view.
        '''
        n = self.phi.shape[0]
        c = np.arange(n)
        n_C = np.zeros(n) + 1
        Utilde = np.copy(self.U)
        
        for eta in range(2,self.eta_min+1):
            print('#nodes in views with sz < %d = %d' % (eta, np.sum(n_C[c]<eta)))
            
            cost = np.zeros(n)+np.inf
            dest = np.zeros(n,dtype='int')-1
            for k in range(n):
                cost[k], dest[k] = cost_of_moving_node(k, self.d_e, self.U, self.phi, self.Psi_gamma,
                                                       self.Psi_i, c, n_C, Utilde, eta, self.eta_max)
            k = np.argmin(cost)
            cost_star = cost[k]
            while cost_star < np.inf:
                # Move point k from its cluster s to dest_k and
                # refresh the sizes and views of both clusters.
                s = c[k]
                dest_k = dest[k]
                c[k] = dest_k
                n_C[s] -= 1
                n_C[dest_k] += 1
                Utilde[dest_k,:] = (Utilde[dest_k,:]==1) | (self.U[k,:])
                Utilde[s,:] = np.any(self.U[c==s,:],0)
                #print('Moved x_%d from C_%d [%d] to C_%d [%d]' % (k, s, n_C[s], dest_k, n_C[dest_k]))
                
                # Only the points affected by the move get their cost recomputed.
                S = np.where((c==dest_k) | (dest==dest_k) | np.any(self.U[:,c==s],1))[0].tolist()
                
                for k in S:
                    cost[k], dest[k] = cost_of_moving_node(k, self.d_e, self.U, self.phi, self.Psi_gamma,
                                                       self.Psi_i, c, n_C, Utilde, eta, self.eta_max)
                
                k = np.argmin(cost)
                cost_star = cost[k]
            print('Remaining #nodes in views with sz < %d = %d' % (eta, np.sum(n_C[c]<eta)))
        
        # Drop empty clusters and relabel the remaining ones as 0..M-1.
        non_empty_C = n_C > 0
        M = np.sum(non_empty_C)
        old_to_new_map = np.arange(n)
        old_to_new_map[non_empty_C] = np.arange(M)
        c = old_to_new_map[c]
        n_C = n_C[non_empty_C]
        
        Psitilde_i = self.Psi_i[non_empty_C,:]
        Psitilde_gamma = self.Psi_gamma[non_empty_C,:]
        betatilde = self.beta[non_empty_C]
        
        # Rebuild the views from scratch with the new labels.
        Utilde = np.zeros((M,n), dtype=bool)
        for m in range(M):
            Utilde[m,:] = np.any(self.U[c==m,:], 0)
        
        self.Utilde = Utilde
        
        zetatilde = np.zeros(M)
        for m in range(M):
            Utilde_m = Utilde[m,:]
            zetatilde[m] = compute_zeta(self.d_e[np.ix_(Utilde_m,Utilde_m)],
                                        eval_param(self.phi, Psitilde_gamma,
                                                   Psitilde_i, m, Utilde_m))
        
        print("After clustering, max distortion is %f" % (np.max(zetatilde)))
        
        return c, Utilde, Psitilde_i, Psitilde_gamma, betatilde, zetatilde

In [2]:
# Generate a sample dataset and run the whole LDLE pipeline on it with the
# default settings (the LDLE constructor performs every stage).
# NOTE(review): Datasets is not defined in this part of the notebook.
X, labelsMat = Datasets().rectangleGrid(RES=100)
#X, labelsMat = Datasets().sphere()
ldle = LDLE(X=X)

X.shape =  (10426, 2)
Atilde_k: 0 points processed...
Atilde_k: 2606 points processed...
Atilde_k: 5212 points processed...
Atilde_k: 7818 points processed...
Atilde_k: 10424 points processed...
Atilde_k: all points processed...
Psi,zeta: 0 points processed...
Psi,zeta: 2606 points processed...
Psi,zeta: 5212 points processed...
Psi,zeta: 7818 points processed...
Psi,zeta: 10424 points processed...
Psi,zeta: all points processed...
After iter 1, max distortion is 18.190035
After iter 2, max distortion is 10.813537
After iter 3, max distortion is 9.786342
After iter 4, max distortion is 9.786342
After iter 5, max distortion is 9.786342
After iter 6, max distortion is 9.786342
After iter 7, max distortion is 9.786342
#nodes in views with sz < 2 = 10426
Remaining #nodes in views with sz < 2 = 0
#nodes in views with sz < 3 = 1876
Remaining #nodes in views with sz < 3 = 0
#nodes in views with sz < 4 = 1278
Remaining #nodes in views with sz < 4 = 0
#nodes in views with sz < 5 = 1088
Remainin

In [None]:
# Largest distortion over the post-processed local views.
np.max(ldle.zeta)

In [28]:
# Switch matplotlib to the Qt backend and create the plotting helper.
# NOTE(review): Visualize is not defined in this part of the notebook.
%matplotlib qt
visualize = Visualize()

In [None]:
# Plot the data with the first column of labelsMat, then the retained
# graph-Laplacian eigenvalues (the trivial one was dropped in LDLE.__init__).
visualize.data(X, labelsMat[:,0])
visualize.eigenvalues(ldle.lmbda)

In [None]:
# Presumably plots column 99 of ldle.phi over the data; with the default
# N = 100 that is the last retained eigenvector - TODO confirm.
visualize.eigenvector(X, ldle.phi, 99)

In [None]:
# Presumably plots the per-point local distortion ldle.zeta over the data
# (s is likely the marker size - TODO confirm).
visualize.distortion(X, ldle.zeta, s=30)

In [None]:
# Inspect the local views together with their post-processed parameters
# (Psi_gamma, Psi_i) and distortions.
visualize.local_views(ldle.X, ldle.phi, ldle.U, ldle.gamma, ldle.Atilde,
                      ldle.Psi_gamma, ldle.Psi_i, ldle.zeta)

In [None]:
# Same inspection for the intermediate views obtained by clustering;
# ldle.c holds the cluster label of every point.
visualize.intermediate_views(ldle.X, ldle.phi, ldle.Utilde, ldle.gamma, ldle.Atilde,

                             ldle.Psitilde_gamma, ldle.Psitilde_i, ldle.zetatilde, ldle.c)

In [None]:
# Presumably shows which eigenvector indices (rows of ldle.Psi_i) were
# chosen for each local view - TODO confirm against Visualize.
visualize.chosen_eigevec_inds_for_local_views(ldle.X, ldle.Psi_i)

In [12]:
# Presumably the same display for the intermediate views, using the
# cluster labels ldle.c - TODO confirm against Visualize.
visualize.chosen_eigevec_inds_for_intermediate_views(ldle.X, ldle.Psitilde_i, ldle.c)

In [15]:
# Compare local views by distortion. The behaviour of this helper is not
# visible here; the recorded output below ends with "Timed out".
visualize.compare_local_high_low_distortion(ldle.X, ldle.Atilde, ldle.Psi_gamma, ldle.Psi_i, ldle.zeta)



[(77.38366935483876, 4.451137514056377)]




[(80.4422043010753, 5.559282959708288)]




[(75.56317204301078, 3.827805700877178)]




Timed out


In [None]:
# Same comparison for the intermediate views; the helper's behaviour is
# not visible in this part of the notebook.
visualize.compare_intermediate_high_low_distortion(ldle.X, ldle.Atilde, ldle.Psitilde_gamma,
                                                   ldle.Psitilde_i, ldle.zetatilde, ldle.c)



[(73.35846774193553, 3.5452067814643833)]




[(73.61155913978499, 4.767649792104258)]




[(76.05107526881724, 6.945126404806533)]




[(59.462365591397884, 2.055354362247037)]




[(40.10887096774198, 1.6733409214220778)]




[(27.26075268817206, 1.4441328569271015)]




[(11.647849462365627, 1.3295288246796133)]




[(96.21774193548393, 7.861958662786439)]




[(67.26881720430111, 3.46880409329939)]




In [None]:
# Total absolute difference between the eigenvector indices of the
# intermediate views and those looked up through the cluster labels.
# NOTE(review): Psitilde_i has one row per intermediate view while
# Psi_i[ldle.c,:] has one row per point, and ldle.c holds the relabelled
# (0..M-1) cluster ids rather than row indices of Psi_i - verify that this
# comparison is the intended one.
np.sum(np.abs(ldle.Psitilde_i-ldle.Psi_i[ldle.c,:]))