# Online Matrix Factorization

Code taken from  
http://nbviewer.jupyter.org/github/gpeyre/numerical-tours/blob/master/python/inverse_5_inpainting_sparsity.ipynb

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

from nt_toolbox.signal import load_image, imageplot, snr
from nt_toolbox.general import clamp
import matplotlib.pyplot as plt
import numpy as np

import warnings
warnings.filterwarnings('ignore')

Here we consider the inpainting of damaged observations, without noise.

In [None]:
img_size = 256  # size argument handed to load_image
f0 = load_image('lena.bmp', img_size)  # swap the file name (e.g. 'image.jpg') to try another picture

plt.figure(figsize=(6, 6))
imageplot(f0, 'Image f_0')

We construct a mask $\Omega$ made of random pixel locations.  
The damaging operator $\Phi$ sets to zero the pixels at the locations $x$ for which $\Omega(x)=1$.  
The damaged observations read $y = \Phi f_0$.

In [None]:
from numpy import random

rho = 0.7  # fraction (not a percentage) of pixels removed by the mask
Omega = np.zeros((img_size, img_size))
sel = random.permutation(img_size**2)
# Flag the first rho*N entries of the random permutation as "removed" (mask value 1).
Omega.flat[sel[:int(rho * img_size**2)]] = 1


def Phi(f, Omega):
    """Damaging operator: zero out the entries of f where Omega equals 1."""
    return f * (1 - Omega)


y = Phi(f0, Omega)

plt.figure(figsize=(6, 6))
imageplot(y, 'Observations y')

### Algorithm 1

Dictionary initialization inspired by  
http://nbviewer.jupyter.org/github/gpeyre/numerical-tours/blob/master/matlab/sparsity_4_dictionary_learning.ipynb

In [None]:
w = 10      # Side length of a square patch, in pixels
m = w ** 2  # Dimension of a vectorised patch (the signal to be sparse coded)
k = m * 2   # Number of dictionary atoms; k > m makes the dictionary overcomplete

Generate a random patch in the damaged image

In [None]:
def random_patch(image, width, n_patches=1):
    '''
    Draw square patches uniformly at random from a 2-D image.

    Args:
        image: 2-D array to sample from.
        width: Side length of each square patch, in pixels.
        n_patches: Number of patches to draw.
    Returns:
        patches: Array of shape (n_patches, width, width) holding a copy of
            each sampled patch.
    Raises:
        ValueError: If the image is smaller than `width` along either axis.
    '''
    n_rows, n_cols = image.shape[0], image.shape[1]
    if width > n_rows or width > n_cols:
        raise ValueError(
            'patch width %d does not fit in an image of shape %s'
            % (width, (n_rows, n_cols))
        )

    # Upper left corners of patches. randint's upper bound is exclusive, so
    # the "+ 1" is required for the last valid corner (size - width) to be
    # reachable, and for an image exactly `width` wide to be accepted.
    rows = np.random.randint(0, n_rows - width + 1, n_patches)
    cols = np.random.randint(0, n_cols - width + 1, n_patches)

    patches = np.zeros((n_patches, width, width))
    for i in range(n_patches):
        patches[i] = image[
            rows[i]:rows[i]+width,
            cols[i]:cols[i]+width
        ]
    return patches

def plot_dictionary(D):
    '''
    Show the atoms of a dictionary side by side as one square mosaic.

    Args:
        D: Array of shape (n_patches, patch_size, patch_size), one square
           atom per entry along the first axis.
    '''
    assert D.ndim == 3
    n_atoms, atom_height, atom_width = D.shape
    assert atom_height == atom_width

    # Number of tiles along each side of the mosaic
    grid = int(np.ceil(np.sqrt(n_atoms)))

    # Surround every atom with a 1-pixel border and append blank tiles so the
    # grid is completely filled; both use the constant value 1.
    border = 1
    n_blank = grid * grid - n_atoms
    tiles = np.pad(
        D,
        ((0, n_blank), (border, border), (border, border)),
        mode='constant',
        constant_values=1,
    )

    tile_size = atom_height + 2 * border
    side = grid * tile_size
    # (grid_row, grid_col, y, x) -> (grid_row, y, grid_col, x) so that the
    # final reshape lays the tiles out row by row in a single 2-D image.
    mosaic = (
        tiles.reshape(grid, grid, tile_size, tile_size)
        .transpose(0, 2, 1, 3)
        .reshape(side, side)
    )
    imageplot(mosaic)

### Algorithm 2 Dictionary update
From "Online Learning for Matrix Factorization and Sparse Coding"

In [None]:
def update_dictionary(D, A, B):
    '''
    Update the dictionary column by column (block-coordinate descent),
    following Algorithm 2 of "Online Learning for Matrix Factorization and
    Sparse Coding":

        u_j <- (b_j - D a_j) / A[j,j] + d_j
        d_j <- u_j / max(||u_j||_2, 1)

    Denoting k the number of atoms in the dictionary and m the size of the signal, we have:

    Args:
        D: dictionary of size (m,k). It is modified in place.
        A: Matrix of size (k,k)
        B: Matrix of size (m,k)
    Returns:
        D: Updated dictionary of size (m,k) (the same array object as the input)
    '''
    (m,k) = D.shape
    assert A.shape == (k,k)
    assert B.shape == (m,k)

    for j in range(k):
        if A[j,j] != 0:
            # Only the residual is scaled by 1/A[j,j]; the current atom d_j is
            # added unscaled. D already contains the atoms updated earlier in
            # this sweep, as the algorithm requires.
            uj = (B[:,j] - np.dot(D, A[:,j])) / A[j,j] + D[:,j]
        else:
            # A[j,j] == 0: no step can be computed for this atom, so it is
            # left as is (it is still projected below). When A and B are the
            # accumulated alpha*alpha^T and x*alpha^T, this means the atom has
            # never been used and columns j of A and B are zero.
            uj = D[:,j]
        # Project onto the unit l2 ball
        D[:,j] = uj / max(np.linalg.norm(uj), 1)
    return D

In [None]:
def evaluate(Y_test, D, model):
    '''
    Reconstruction error of test signals sparse-coded with dictionary D.

    Args:
        Y_test: Test signals, one per column: shape (m,n), or (m,) for a
            single signal.
        D: Dictionary of size (m,k)
        model: Estimator exposing fit(D, Y_test) that returns an object with
            a coef_ attribute (e.g. sklearn.linear_model.Lasso). It is
            refitted by this call.
    Returns:
        error: Frobenius norm of Y_test - D alpha^T
    '''
    k = D.shape[1]
    # coef_ is (n,k) for several signals but may come back 1-D when there is
    # a single one; force one row of k coefficients per signal so the
    # subtraction below can never broadcast into an (m,m) matrix.
    alpha = np.reshape(model.fit(D, Y_test).coef_, (-1, k))
    signals = np.reshape(Y_test, (D.shape[0], -1))
    return np.linalg.norm(signals - np.dot(D, alpha.T))

### Algorithm 1 Online dictionary learning
From "Online Learning for Matrix Factorization and Sparse Coding"

In [None]:
import time
from tqdm import tqdm
from sklearn import linear_model

# Online dictionary learning driver: at every iteration draw one patch,
# sparse-code it against the current dictionary, accumulate the statistics
# A and B, then refresh the dictionary with update_dictionary.
# NOTE(review): all patches (training and evaluation) are sampled from f0,
# not from the damaged observations y -- confirm this is intended.
# NOTE(review): no random seed is set, so results differ between runs.

# Initialize variables
T = 100 # Number of iterations
lambd = 0.1 # L1 penalty coefficient for alpha
# LARS-Lasso from LEAST ANGLE REGRESSION, Efron et al http://statweb.stanford.edu/~tibs/ftp/lars.pdf
lasso = linear_model.Lasso(lambd, fit_intercept=False) # TODO: use lars instead of lasso

D = random_patch(f0, w, n_patches=k) # Initialize dictionary with k random atoms
D = D.reshape(k, m).T # Reshape each atom to column vector; D is now (m,k)
# TODO: normalize atom to unit norm as sparsity_4_dictionary_learning ?
# Running sums of alpha*alpha^T (k,k) and x*alpha^T (m,k) over the patches seen so far
A = np.zeros((k,k))
B = np.zeros((m,k))

# Evaluation data initialization
sparsity = [] # Number of non-zero coefficients of the latest code, recorded every 10 iterations
error = [] # Reconstruction error on Y_test, recorded every 10 iterations
n = 20*k # Number of patches to take for evaluation
Y_test = random_patch(f0, w, n_patches=n)
Y_test = Y_test.reshape((n,m)).T # Reshape each patch to column vector; Y_test is (m,n)


# Show the initial dictionary (one w x w tile per atom)
plt.figure(figsize=(8,12))
plot_dictionary(D.T.reshape(k, w, w))

start = time.time()
for t in tqdm(range(T)):
    x = random_patch(f0, w, n_patches=1).reshape((m,1)) # Draw 1 random patch as column vector
    alpha = lasso.fit(D, x).coef_.reshape((k,1)) # Get the sparse coding # TODO: try with lasso.sparse_coef_
    # Accumulate the statistics consumed by the dictionary update
    A += np.dot(alpha,alpha.T)
    B += np.dot(x,alpha.T)
    D = update_dictionary(D, A, B)
    
    if t%10 == 0:
        # Evaluation:
        # evaluate() refits `lasso` on Y_test; the next iteration refits it on a new patch
        error.append(evaluate(Y_test, D, lasso))
        sparsity.append(np.sum(alpha!=0))#/alpha.shape[0]
end = time.time()

# The elapsed time includes the periodic evaluations performed inside the loop
print('Time elapsed: %.3f s' % (end-start))
# Show the learned dictionary
plt.figure(figsize=(8,12))
plot_dictionary(D.T.reshape(k, w, w))
# Evaluation error, one point per 10 iterations (`sparsity` is collected but not plotted here)
plt.figure(figsize=(8,12))
plt.plot(error)
plt.show()