# Energy auto-encoder: comparison with Xavier Primal-Dual matlab implementation

## Setup

In [None]:
%matplotlib inline
import numpy as np
import numpy.linalg as la
import scipy.io
import matplotlib.pyplot as plt
from pyunlocbox import functions, solvers
import time
import scipy, matplotlib, pyunlocbox  # For versions only.
# Report the version of every main dependency, for reproducibility.
print('Software versions:')
versions = ((p.__name__, p.__version__) for p in (np, matplotlib, scipy, pyunlocbox))
for name, version in versions:
    print('  %s: %s' % (name, version))

## Hyper-parameters

* The $\lambda$ hyper-parameters set the relative importance of each term in the composite objective function.

In [None]:
# Weight lambda_d of the data-fidelity term ||X - DZ||_F^2 in the objective
# (the L1 prior on Z keeps a unit weight).
# NOTE(review): presumably equivalent, up to a global scaling of the
# objective, to Xavier's setting below since 10 = 1 / 1e-1 -- confirm against
# the MATLAB code.
l_d = 10  # Xavier sets the weight of the L1 regularization to 1e-1.

## Data

* The set of data vectors $X \in R^{n \times N}$ is given by patches extracted from a grayscale image.
* There are as many patches as pixels in the image.
* The saved patches already have zero mean.

In [None]:
# Load the patches: X has one patch per column (n pixels x N patches).
mat = scipy.io.loadmat('data/xavier_X.mat')
X = mat['X']

n, N = X.shape
# Side length of the (square) patches. It is cast to int because np.sqrt
# returns a float, which NumPy does not accept as a reshape dimension.
Np = int(round(np.sqrt(n)))
print('N = %d samples with dimensionality n = %d (patches of %dx%d).' % (N, n, Np, Np))

# Display a few patches side by side.
plt.figure(figsize=(8,5))
patches = [24, 1000, 2004, 10782]
for k, patch in enumerate(patches):
    img = np.reshape(X[:, patch], (Np, Np))
    plt.subplot(1, len(patches), k+1)  # Subplot indices start at 1.
    plt.imshow(img, cmap='gray')
    plt.title('Patch %d' % patch)
plt.show()

## Initial conditions

* $Z$ is drawn from a uniform distribution in ]0,1[.
* Same for $D$. Its columns were then normalized to unit L2 norm.
* The sparse code dimensionality $m$ should be greater than $n$ for an overcomplete representation but much smaller than $N$ to avoid over-fitting.

In [None]:
# Load the initial sparse codes Z (m x N) and dictionary D (n x m).
mat = scipy.io.loadmat('data/xavier_initZD.mat')
Zinit = mat['Zinit']
Dinit = mat['Dinit']

m, N = Zinit.shape
n, m = Dinit.shape
print('Sparse code dimensionality m = %d --> %s dictionary' % (m, 'overcomplete' if m > n else 'undercomplete'))

print('mean(Z) = %f' % np.mean(Zinit))

# L2 norm of each atom (column of D). The constraint is checked with a small
# tolerance to absorb rounding errors.
d = np.sqrt(np.sum(Dinit*Dinit, axis=0))
# np.all replaces np.alltrue, which was removed in NumPy 2.0.
print('Constraints on D: %s' % np.all(d <= 1+1e-15))

## Algorithm

Given $X \in R^{n \times N}$, solve $\min\limits_{Z \in R^{m \times N}, D \in R^{n \times m}} \frac{\lambda_d}{2} \|X - DZ\|_F^2 + \|Z\|_1$ s.t. $\|d_i\|_2 \leq 1$, $i = 1, \ldots, m$

In [None]:
# Alternating minimization of the objective over Z (sparse codes) and D
# (dictionary): each outer iteration solves the Z sub-problem with D fixed,
# then the D sub-problem with Z fixed, both with FISTA.

# Solver numeric parameters.
N_outer = 15  # Xavier uses 15.
N_inner = 100  # Xavier uses 1e-5 or 1e-7 and 100 iterations for inner loops.

# Static loss function definitions.
g_z = functions.norm_l1()
g_d1 = functions.proj_b2(epsilon=1)  # L2-ball indicator function.
g_d = functions.func()
g_d._eval = lambda Dt: 0
# The D sub-problem is solved on D^T, so transpose before and after the
# projection to constrain the atoms.
g_d._prox = lambda Dt,_: g_d1._prox(Dt.T, 0).T  # Constraints on lines of D^T.

# Initialization. Work on copies (same memory layout) so that the loaded
# initial conditions are never overwritten: the solver may update its starting
# point in place, and D.T below is a view on D. This keeps the cell
# re-runnable from the same starting point.
Z = Zinit.copy(order='K')
D = Dinit.copy(order='K')
objective_z = []
objective_d = []
objective_g = []
tstart = time.time()

# Multi-variate non-convex optimization (outer loop).
for _ in range(N_outer):

    # Convex minimization for Z.
    f_z = functions.norm_l2(lambda_=l_d/2., A=D, y=X, tight=False)
    # Lipschitz constant of the gradient. la.norm defaults to the Frobenius
    # norm for matrices, which upper-bounds the spectral norm: the step 1/L is
    # thus valid, although conservative.
    L = l_d * la.norm(np.dot(D.T, D))
    solver = solvers.forward_backward(step=1./L, method='FISTA')
    ret = solvers.solve([f_z, g_z], Z, solver, rtol=None, xtol=1e-6, maxit=N_inner, verbosity='NONE')
    Z = ret['sol']
    objective_z.extend(ret['objective'])
    # Pad the other sub-problem with zeros to keep both histories aligned.
    objective_d.extend(np.zeros(np.shape(ret['objective'])))

    # Convex minimization for D (solved on the transposed problem).
    f_d = functions.norm_l2(lambda_=l_d/2., A=Z.T, y=X.T, tight=False)
    L = l_d * la.norm(np.dot(Z, Z.T))  # Lipschitz continuous gradient.
    solver = solvers.forward_backward(step=1./L, method='FISTA')
    ret = solvers.solve([f_d, g_d], D.T, solver, rtol=None, xtol=1e-6, maxit=N_inner, verbosity='NONE')
    D = ret['sol'].T
    objective_d.extend(ret['objective'])
    objective_z.extend(np.zeros(np.shape(ret['objective'])))

    # Global objective (the indicators are 0).
    objective_g.append(g_z.eval(Z) + f_d.eval(D.T))

print('Elapsed time: %d seconds' % (time.time() - tstart))

## Convergence

In [None]:
# Objective histories of the two sub-problems: one row per inner iteration.
# Each history is zero-padded while the other sub-problem is being solved.
history_z = np.array(objective_z)
history_d = np.array(objective_d)
niter = np.shape(objective_z)[0]

plt.figure(figsize=(8,5))
for values, label in ((history_z[:, 0], 'Z: data term'),
                      (history_z[:, 1], 'Z: prior term'),
                      (history_d[:, 0], 'D: data term')):
    plt.semilogy(values, label=label)
plt.xlim(0, niter-1)
plt.title('Sub-problems convergence')
plt.xlabel('Iteration number (inner loops)')
plt.ylabel('Objective function value')
plt.grid(True)
plt.legend()
plt.show()
print('Inner loop: %d iterations' % niter)

# Global objective: one value per outer iteration.
niter = np.shape(objective_g)[0]
plt.figure(figsize=(8,5))
plt.plot(objective_g)
plt.xlim(0, niter-1)
plt.title('Global convergence')
plt.xlabel('Iteration number (outer loop)')
plt.ylabel('Objective function value')
plt.grid(True)
plt.show()
print('Outer loop: %d iterations\n' % niter)

# Final value of each term of the objective.
final_terms = (
    ('g_z(Z) = %e', g_z.eval(Z)),
    ('f_z(Z,D) = %e', f_z.eval(Z)),
    ('f_d(D,Z) = %e', f_d.eval(D.T)),
    ('g_z(Z) + f_d(D,Z) = %e', objective_g[-1]),
)
for template, value in final_terms:
    print(template % value)

## Solution analysis

### Solution from Xavier

In [None]:
# Load the solution (Z, D) and the execution time of the MATLAB implementation.
mat = scipy.io.loadmat('data/xavier_ZD.mat')
Zxavier = mat['Z']
Dxavier = mat['D']
# loadmat returns MATLAB scalars as arrays: extract the Python scalar
# explicitly, as the implicit conversion of an array with ndim > 0 to a
# scalar is deprecated in NumPy.
print('Elapsed time: %d seconds' % mat['exectime'].item())

# Evaluate Xavier's solution with the same loss functions as ours.
# Note that this rebinds g_z, f_z and f_d.
g_z = functions.norm_l1()
f_z = functions.norm_l2(lambda_=l_d/2., A=Dxavier, y=X, tight=False)
f_d = functions.norm_l2(lambda_=l_d/2., A=Zxavier.T, y=X.T, tight=False)

print('g_z(Z) = %e' % g_z.eval(Zxavier))
print('f_z(Z,D) = %e' % f_z.eval(Zxavier))
print('f_d(D,Z) = %e' % f_d.eval(Dxavier.T))
print('g_z(Z) + f_d(D,Z) = %e' % (g_z.eval(Zxavier) + f_d.eval(Dxavier.T)))

### Sparse codes

In [None]:
def sparse_code(Z):
    """Print the sparsity of the sparse codes Z and plot their sparsity pattern.

    Parameters
    ----------
    Z : 2-D array
        Sparse codes, with one row per atom and one column per sample.
    """
    # Dimensions are read from Z itself rather than from global variables, so
    # that the labels always match the plotted matrix.
    n_atoms, n_samples = Z.shape

    nnz = np.count_nonzero(Z)
    print('Sparsity of Z: %d non-zero entries out of %d entries, i.e. %.1f%%.' % (nnz, Z.size, 100.*nnz/Z.size))

    plt.figure(figsize=(8,5))
    plt.spy(Z, precision=0, aspect='auto')
    plt.xlabel('N = %d samples' % n_samples)
    plt.ylabel('m = %d atoms' % n_atoms)
    plt.show()

sparse_code(Zxavier)
sparse_code(Z)

### Dictionary

In [None]:
def dictionary(D):
    """Check the norm constraints on the dictionary D and plot its content.

    Prints whether every atom has an L2 norm of at most 1 (up to a small
    tolerance), then plots the atom norms and the sparsity pattern of D.

    Parameters
    ----------
    D : 2-D array
        Dictionary, with one atom per column.
    """
    # Dimensions are read from D itself rather than from global variables.
    n_dim, n_atoms = D.shape

    # L2 norm of each atom.
    d = np.sqrt(np.sum(D*D, axis=0))
    # np.all replaces np.alltrue, which was removed in NumPy 2.0.
    print('Constraints on D: %s' % np.all(d <= 1+1e-15))

    plt.figure(figsize=(8,5))
    plt.plot(d, 'b.')
    plt.ylim(0.5, 1.5)
    plt.xlim(0, n_atoms-1)
    plt.title('Dictionary atom norms')
    plt.xlabel('Atom [1,m]')
    plt.ylabel('Norm [0,1]')
    plt.grid(True)
    plt.show()

    # Sparsity pattern: entries with a magnitude above 1e-2.
    plt.figure(figsize=(8,5))
    plt.spy(D, precision=1e-2, aspect='auto')
    plt.xlabel('m = %d atoms' % (n_atoms,))
    plt.ylabel('data dimensionality of n = %d' % n_dim)
    plt.show()

    # TODO: use plt.scatter to show intensity.

dictionary(Dxavier)
dictionary(D)

In [None]:
def atoms(D):
    """Show every atom of the dictionary D as a square grayscale image.

    Parameters
    ----------
    D : 2-D array
        Dictionary, with one atom per column. The atom dimensionality is
        expected to be a perfect square so each atom reshapes to a square
        patch.
    """
    n_dim, n_atoms = D.shape
    # Patch side and grid dimensions must be integers: they are used as
    # reshape and subplot dimensions.
    side = int(round(np.sqrt(n_dim)))
    n_cols = int(np.ceil(np.sqrt(n_atoms)))
    n_rows = int(np.ceil(n_atoms / float(n_cols)))

    plt.figure(figsize=(8,8))
    for k in range(n_atoms):
        plt.subplot(n_rows, n_cols, k+1)  # Subplot indices start at 1.
        img = D[:, k].reshape(side, side)
        plt.imshow(img, cmap='gray')  # vmin=0, vmax=1 to disable normalization.
        plt.axis('off')
    plt.show()

atoms(Dxavier)
atoms(D)