# Sparse energy auto-encoders

* The definition of the algorithm behind our sparse energy auto-encoder model.
* It is an unsupervised feature extraction tool which tries to find a good sparse representation in an efficient manner.
* This notebook is meant to be imported by other notebooks for applications to image or audio data.
* Modeled after the sklearn Estimator class so that it can be integrated into an sklearn Pipeline. Note that matrices are transposed in the code with respect to the math (samples are rows, not columns) to follow sklearn conventions.

## Algorithm

General problem:  
* given $X \in R^{n \times N}$,
* solve $\min\limits_{Z \in R^{m \times N}, D \in R^{n \times m}, E \in R^{m \times n}} \frac{\lambda_d}{2} \|X - DZ\|_F^2 + \frac{\lambda_e}{2} \|Z - EX\|_F^2 + \|Z\|_1$
* s.t. $\|d_i\|_2 \leq 1$, $\|e_k\|_2 \leq 1$, $i = 1, \ldots, m$, $k = 1, \ldots, n$

which, without the encoder term (i.e. $\lambda_e = 0$), reduces to sparse coding with dictionary learning:  
* given $X \in R^{n \times N}$,
* solve $\min\limits_{Z \in R^{m \times N}, D \in R^{n \times m}} \frac{\lambda_d}{2} \|X - DZ\|_F^2 + \|Z\|_1$
* s.t. $\|d_i\|_2 \leq 1$, $i = 1, \ldots, m$

Observations:
* Almost ten times faster (on comparison_xavier) using optimized linear algebra subroutines:
    * None: 9916s
    * ATLAS: 1335s (is memory bandwidth limited)
    * OpenBLAS: 1371s (seems more CPU intensive than ATLAS)

Open questions:
* First optimize for Z (last impl) or first for D/E (new impl) ?
    * Seem to converge much faster if Z optimized last (see comparison_xavier).
    * But two times slower.
    * In fit we optimize for parameters D, E so it makes sense to optimize them last.
* Fast evaluation of la.norm(Z.T.dot(Z)). Cumulative to save memory ?
* Consider adding an option for $E = D^T$
* Use single precision, i.e. float32 ?

In [None]:
import numpy as np
import numpy.linalg as la
from pyunlocbox import functions, solvers
import matplotlib.pyplot as plt
%matplotlib inline

def _normalize(X, axis=1):
    """Normalize the selected axis of an ndarray to unit norm."""
    return X / np.sqrt(np.sum(X**2, axis))[:,np.newaxis]

class auto_encoder(object):
    """
    Sparse energy auto-encoder.

    Exposes the sklearn estimator interface (fit, fit_transform, transform)
    so that it can be used as a step of an sklearn Pipeline. Data matrices
    are laid out as (samples, features), i.e. transposed with respect to
    the mathematical notation.
    """
    
    def __init__(self, m=100, ld=None, le=None, lg=None,
                 rtol=1e-3, xtol=None, N_inner=100, N_outer=15):
        """
        Model hyper-parameters and solver stopping criteria.
        
        Model hyper-parameters:
            m:  number of atoms in the dictionary, sparse code length
            ld: weight of the dictionary l2 penalty (None: no dictionary D)
            le: weight of the encoder l2 penalty (None: no encoder E)
            lg: weight of the graph smoothness (stored, not used yet)
        
        Stopping criteria:
            rtol: objective function convergence
            xtol: model parameters convergence
            N_inner: hard limit of inner iterations
            N_outer: hard limit of outer iterations
        """
        self.m = m
        self.ld = ld
        self.le = le
        self.lg = lg
        self.N_outer = N_outer
        
        # Solver common parameters, shared by the Z, D and E sub-problems.
        self.params = {'rtol':       rtol,
                       'xtol':       xtol,
                       'maxit':      N_inner,
                       'verbosity': 'NONE'}

    def _convex_functions(self, X, Z):
        """
        Define the convex functions of the three sub-problems.

        The functions keep references to the X and Z arrays (and to self.D,
        self.E), so they see the in-place updates made by the solvers.
        NOTE(review): this relies on pyunlocbox's solve() updating its x0
        argument in place -- confirm for the installed version.
        """
        
        # Unit-ball constraint on the atoms. The transposes make proj_b2
        # act on rows -- presumably it projects columns; TODO confirm.
        f = functions.proj_b2()
        self.f = functions.func()
        self.f._eval = lambda X: 0
        self.f._prox = lambda X,_: f._prox(X.T, 1).T
        #self.f._prox = lambda X,_: _normalize(X)
        
        if self.ld is not None:
            # ld/2 ||ZD - X||^2, as a function of D and of Z.T.
            self.g_d = functions.norm_l2(lambda_=self.ld/2., A=Z, y=X, tight=False)
            g_z = functions.norm_l2(lambda_=self.ld/2., A=self.D.T, y=X.T, tight=False)
        if self.le is not None:
            # le/2 ||XE - Z||^2, as a function of E and of Z.T. The target
            # for Z is a callable so it is recomputed with the current E.
            self.h_e = functions.norm_l2(lambda_=self.le/2., A=X, y=Z, tight=False)
            h_z = functions.norm_l2(lambda_=self.le/2., y=lambda: X.dot(self.E).T, tight=True)
        
        # Smooth part of the Z sub-problem.
        if self.ld is not None and self.le is None:
            self.gh_z = g_z
        elif self.ld is None and self.le is not None:
            self.gh_z = h_z
        elif self.ld is not None and self.le is not None:
            self.gh_z = functions.func()
            self.gh_z._eval = lambda Z: g_z._eval(Z) + h_z._eval(Z)
            self.gh_z._grad = lambda Z: g_z._grad(Z) + h_z._grad(Z)
        else:
            raise ValueError('Either ld or le should be defined.')
            
        # Non-smooth sparsity prior on Z.
        self.i_z = functions.norm_l1()

    def _minD(self, X, Z):
        """Convex minimization for D (FISTA, constrained atoms)."""
        
        # Lipschitz continuous gradient. Faster if larger dim is 'inside'.
        # The Frobenius norm upper-bounds the spectral norm, so 1/L is safe.
        L = self.ld * la.norm(Z.T.dot(Z))
        
        solver = solvers.forward_backward(step=1./L, method='FISTA')
        ret = solvers.solve([self.g_d, self.f], self.D, solver, **self.params)
        
        # Pad the other histories so all three stay aligned per iteration.
        self.objective_d.extend(ret['objective'])
        self.objective_z.extend([[0,0]] * len(ret['objective']))
        self.objective_e.extend([[0,0]] * len(ret['objective']))
    
    def _minE(self, X, Z):
        """Convex minimization for E (FISTA, constrained rows)."""
        
        # Lipschitz continuous gradient. Faster if larger dim is 'inside'.
        L = self.le * la.norm(X.T.dot(X))
        
        solver = solvers.forward_backward(step=1./L, method='FISTA')
        ret = solvers.solve([self.h_e, self.f], self.E, solver, **self.params)
        
        # Pad the other histories so all three stay aligned per iteration.
        self.objective_e.extend(ret['objective'])
        self.objective_z.extend([[0,0]] * len(ret['objective']))
        self.objective_d.extend([[0,0]] * len(ret['objective']))
    
    def _minZ(self, X, Z):
        """Convex minimization for Z (FISTA, l1 prior); updates Z in place."""
        
        # Lipschitz constant of the smooth part: le for the identity-operator
        # encoder term, ld * ||D D^T|| for the dictionary term.
        L_e = self.le if self.le is not None else 0
        L_d = self.ld * la.norm(self.D.T.dot(self.D)) if self.ld is not None else 0
        L = L_d + L_e
        
        solver = solvers.forward_backward(step=1./L, method='FISTA')
        ret = solvers.solve([self.gh_z, self.i_z], Z.T, solver, **self.params)
        
        # Pad the other histories so all three stay aligned per iteration.
        self.objective_z.extend(ret['objective'])
        self.objective_d.extend([[0,0]] * len(ret['objective']))
        self.objective_e.extend([[0,0]] * len(ret['objective']))
        
    def fit_transform(self, X, y=None):
        """
        Fit the model parameters (dictionary, encoder and graph)
        given training data.
        
        Parameters
        ----------
        X : ndarray, shape (N, n)
            Training vectors, where N is the number of samples
            and n is the number of features.
        y : ignored
            Accepted for sklearn Pipeline compatibility (unsupervised).
            
        Returns
        -------
        Z : ndarray, shape (N, m)
            Sparse codes (a by-product of training), where N
            is the number of samples and m is the number of atoms.
        """
        N, n = X.shape
        
        # Model parameters initialization.
        if self.ld is not None:
            self.D = _normalize(np.random.uniform(size=(self.m, n)).astype(X.dtype))
        if self.le is not None:
            self.E = _normalize(np.random.uniform(size=(n, self.m)).astype(X.dtype))
        
        # Initial predictions.
        #Z = np.random.uniform(size=(N, self.m)).astype(X.dtype)
        Z = np.zeros(shape=(N, self.m), dtype=X.dtype)
        
        # Initialize convex functions.
        self._convex_functions(X, Z)
        
        # Objective functions.
        self.objective = []
        self.objective_z = []
        self.objective_d = []
        self.objective_e = []
        
        # Multi-variate non-convex optimization (outer loop).
        for _ in range(self.N_outer):

            self._minZ(X, Z)

            if self.ld is not None:
                self._minD(X, Z)

            if self.le is not None:
                self._minE(X, Z)

            # Global objective function.
            self.objective.append(self.gh_z.eval(Z.T) + self.i_z.eval(Z.T))
            
        return Z
    
    def fit(self, X, y=None):
        """
        Fit to data without returning the transformed data.

        Returns self (sklearn convention), which allows chaining such as
        ``auto_encoder(...).fit(X).transform(X)``. y is ignored.
        """
        self.fit_transform(X)
        return self
    
    def transform(self, X):
        """Predict sparse codes for each sample in X."""
        return self._transform_exact(X)
        
    def _transform_exact(self, X):
        """
        Most accurate but slowest prediction.

        Solves the Z sub-problem with the learned D and/or E kept fixed. The
        inner iterations are appended to the objective histories.
        """
        N = X.shape[0]
        Z = np.random.uniform(size=(N, self.m)).astype(X.dtype)
        self._convex_functions(X, Z)
        self._minZ(X, Z)
        return Z
    
    def _transform_approx(self, X):
        """Much faster approximation using only the encoder."""
        raise NotImplementedError('Not yet implemented')
    
    def inverse_transform(self, Z):
        """
        Return the data corresponding to the given sparse codes using
        the learned dictionary.
        """
        raise NotImplementedError('Not yet implemented')
    
    def plot_objective(self):
        """Plot the objective (cost, loss, energy) functions."""
        plt.figure(figsize=(8,5))
        plt.semilogy(np.array(self.objective_z)[:, 0], label='Z: data term')
        plt.semilogy(np.array(self.objective_z)[:, 1], label='Z: prior term')
        #plt.semilogy(np.sum(objective[:,0:2], axis=1), label='Z: sum')
        plt.semilogy(np.array(self.objective_d)[:, 0], label='D: data term')
        plt.semilogy(np.array(self.objective_e)[:, 0], label='E: data term')
        niter = np.shape(self.objective_z)[0]
        plt.xlim(0, niter-1)
        plt.title('Sub-problems convergence')
        plt.xlabel('Iteration number (inner loops)')
        plt.ylabel('Objective function value')
        plt.grid(True); plt.legend(); plt.show()
        print('Inner loop: {} iterations'.format(niter))

        plt.figure(figsize=(8,5))
        plt.plot(self.objective)
        niter = len(self.objective)
        plt.xlim(0, niter-1)
        plt.title('Global convergence')
        plt.xlabel('Iteration number (outer loop)')
        plt.ylabel('Objective function value')
        plt.grid(True); plt.show()
        print('Outer loop: {} iterations\n'.format(niter))

## Tools for solution analysis

Tools to show model parameters, sparse codes and objective function. The *auto_encoder* class solely contains the core algorithm (and a visualization of the convergence).

In [None]:
def objective(X, Z, D, ld):
    """
    Print the value of the objective function and of its two terms.

    With the layout used by auto_encoder -- X (N, n), Z (N, m), D (m, n) --
    the weighted data term ld/2 ||X - ZD||_F^2 is evaluated twice, once as
    a function of Z and once as a function of D, and both values are
    checked to agree.
    """
    g_z = functions.norm_l1()
    f_z = functions.norm_l2(lambda_=ld/2., A=D.T, y=X.T, tight=False)
    f_d = functions.norm_l2(lambda_=ld/2., A=Z, y=X, tight=False)

    g_z = g_z.eval(Z.T)
    f_z = f_z.eval(Z.T)
    f_d = f_d.eval(D)
    # np.isclose stays well defined when the data term is exactly zero,
    # unlike a relative error that divides by f_z.
    assert np.isclose(f_z, f_d, rtol=1e-5), \
        'data terms disagree: {} vs {}'.format(f_z, f_d)
    
    # The data term value includes the ld/2 weight.
    print('||Z||_1 = {:e}'.format(g_z))
    print('ld/2 ||X-DZ||_F^2 = {:e}'.format(f_z))
    print('||Z||_1 + ld/2 ||X-DZ||_F^2 = {:e}'.format(g_z + f_z))

def sparse_codes(Z, tol=0):
    """
    Show the sparsity of the sparse codes.

    Z is (N, m), one code per row. Entries whose magnitude is <= tol are
    counted as zeros (exact zeros when tol is 0). Prints the value range and
    the fraction of non-zero entries, then plots the sparsity pattern.
    """
    N, m = Z.shape
    
    print('Z in [{}, {}]'.format(np.min(Z), np.max(Z)))
    
    # Compare by value: `tol is 0` relied on int identity (a SyntaxWarning on
    # Python 3.8+) and was False for an equal float such as tol=0.0.
    if tol == 0:
        nnz = np.count_nonzero(Z)
    else:
        nnz = np.sum(np.abs(Z) > tol)
    print('Sparsity of Z: {:,} non-zero entries out of {:,} entries, '
          'i.e. {:.1f}%.'.format(nnz, Z.size, 100.*nnz/Z.size))

    plt.figure(figsize=(8,5))
    plt.spy(Z.T, precision=tol, aspect='auto')
    plt.xlabel('N = {} samples'.format(N))
    plt.ylabel('m = {} atoms'.format(m))
    plt.show()
    
def dictionary(D, tol=1e-7):
    """
    Show the norms and sparsity of the learned dictionary.

    D is (m, n), one atom per row. Prints the value range of D and of the
    atom norms, checks the constraint ||d_i||_2 <= 1 up to `tol`, then plots
    the atom norms and the sparsity pattern of D.
    """
    m, n = D.shape

    print('D in [{}, {}]'.format(np.min(D), np.max(D)))
    
    # Euclidean norm of each atom (row).
    d = np.sqrt(np.sum(D**2, axis=1))
    print('d in [{}, {}]'.format(np.min(d), np.max(d)))
    # np.all rather than np.alltrue, which was deprecated and removed in NumPy 2.0.
    print('Constraints on D: {}'.format(np.all(d <= 1+tol)))
    
    plt.figure(figsize=(8,5))
    plt.plot(d, 'b.')
    #plt.ylim(0.5, 1.5)
    plt.xlim(0, m-1)
    plt.title('Dictionary atom norms')
    plt.xlabel('Atom [1,m]')
    plt.ylabel('Norm [0,1]')
    plt.grid(True); plt.show()

    plt.figure(figsize=(8,5))
    plt.spy(D.T, precision=1e-2, aspect='auto')
    plt.xlabel('m = {} atoms'.format(m))
    plt.ylabel('data dimensionality of n = {}'.format(n))
    plt.show()
    
    # TODO: use plt.scatter to show intensity.
    
def atoms(D, Np=None):
    """
    Show dictionary atoms, one subplot per atom (row of D).
    
    2D atoms if Np is not None: each atom is reshaped to an Np x Np image,
    so n must equal Np**2. Otherwise 1D atoms are plotted as curves.
    """
    m, n = D.shape
    
    fig = plt.figure(figsize=(8,8))
    # add_subplot needs integer grid dimensions.
    Nx = int(np.ceil(np.sqrt(m)))
    Ny = int(np.ceil(m / float(Nx)))
    for k in range(m):
        # Subplot indices are 1-based.
        ax = fig.add_subplot(Ny, Nx, k + 1)
        if Np is not None:
            img = D[k,:].reshape(Np, Np)
            ax.imshow(img, cmap='gray')  # vmin=0, vmax=1 to disable normalization.
            ax.axis('off')
        else:
            ax.plot(D[k,:])
            ax.set_xlim(0, n-1)
            ax.set_ylim(-1, 1)
            ax.set_xticks([])
            ax.set_yticks([])

## Unit tests

Test the auto-encoder class and tools.

In [None]:
if False:
    # Check which BLAS implementation numpy uses
    # (from a shell: ldd numpy/core/_dotblas.so).
    # NOTE(review): _dotblas only exists in old numpy releases; on newer ones
    # rely on np.__config__.show() below -- confirm for your version.
    try:
        import numpy.core._dotblas
        print('fast BLAS')
    except ImportError:
        print('slow BLAS')

    print(np.__version__)
    np.__config__.show()

if False:
# if __name__ == '__main__':
    import time

    # Data.
    N, n = 11, 25
    X = np.random.normal(size=(N, n))

    # Algorithm: smoke-test the encoder-only and dictionary-only variants,
    # then time the full model.
    auto_encoder(m=16, le=1, rtol=1e-5, xtol=None).fit(X)
    auto_encoder(m=16, ld=1, rtol=1e-5, xtol=None).fit(X)
    ae = auto_encoder(m=16, ld=1, le=1, rtol=None, xtol=1e-5)
    tstart = time.time()
    Z = ae.fit_transform(X)
    print('Elapsed time: {:.3f} seconds'.format(time.time() - tstart))
    ae.plot_objective()
    
    # Inference on the training data should recover the training codes.
    assert la.norm(Z - ae.transform(X)) / np.sqrt(Z.size) < 1e-3

    # Results visualization.
    objective(X, Z, ae.D, 1)
    sparse_codes(Z)
    dictionary(ae.D)
    atoms(ae.D, 5)  # 2D atoms (n = 25 = 5x5).
    atoms(ae.D)  # 1D atoms.