# Non-Negative Matrix Factorization

## Implementing NMF for Image Analysis

This notebook can run with either Python 2 or Python 3 as a backend.
To switch easily from Python3 to Python2, use the conda [environments](https://conda.io/docs/user-guide/tasks/manage-environments.html).

In [None]:
import os
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
%matplotlib inline

### Load data

In [None]:
# Directory holding the face images. The trailing slash matters because file
# paths are built from it by plain string concatenation.
image_dir = "face/"
# os.listdir returns entries in arbitrary, platform-dependent order; sorting
# makes the image order (and therefore the row order of the data matrix and
# the sample face that gets displayed) the same on every run and machine.
files = sorted(os.listdir(image_dir))
n = len(files)  # number of directory entries, i.e. images to load
print("Loading " + str(n) + " images")

In [None]:
# Read every listed file into memory. Copying inside the `with` block keeps
# the pixel data available after the underlying file has been closed.
imgs = []
for file_name in files:
    with Image.open(image_dir + file_name) as handle:
        imgs.append(handle.copy())
print(imgs[1].mode)  # mode of one sample; the images are gray scale

In [None]:
# Preview a single face together with its intensity scale.
fig, ax = plt.subplots(figsize=(3, 3))
im = ax.imshow(imgs[3], interpolation='none', cmap='gray')
fig.colorbar(im, ax=ax)
ax.set_title('A face', size=16)
plt.show()

### Build $X$

In [None]:
# The first image supplies the dimensions; all images have the same shape.
img0 = imgs[0]
width, height = img0.size
print("image shape: ", (width, height))

# Flatten each image into a 1-D pixel vector and stack the vectors as the
# rows of the input matrix X.
X_list = []
for face in imgs:
    X_list.append(np.ravel(face.getdata()))
X = np.array(X_list, dtype=np.float32)

print("X shape: ", X.shape)

Each row of $X$ represents one flattened image: there are $2429$ images of $19\times 19=361$ pixels each.

### Initialize $U$ and $V$

In [None]:
rank = 40   # inner dimension of the factorization (number of rows of U and V)
lam = 5e-1  # lambda value
# Fix the seed so the random initialization, and with it the whole
# factorization, is reproducible across kernel restarts.
np.random.seed(0)
# ---- random initialization, entries drawn uniformly from [0, 1)
# print() with a single argument behaves the same on Python 2 and Python 3.
print("rand. init.")
U = np.random.random((rank, X.shape[0]))  # one column per row of X
V = np.random.random((rank, X.shape[1]))  # one column per column of X
# ---- alternative: initialize both factors to a constant matrix
# c = .5
# print("constant init., c: {}".format(c))
# U = c * np.ones((rank, X.shape[0]))
# V = c * np.ones((rank, X.shape[1]))

### Run the Projected ALS algorithm

In [None]:
nm_iterations = 200
# The regularization term lam * I is the same in every update, so build it once.
ridge = lam * np.identity(rank)
for _ in range(nm_iterations):
    # Update V with U held fixed: least-squares solve of
    # (U U^T + lam I) V = U X, then project onto the non-negative orthant.
    gram_u = np.dot(U, U.T) + ridge
    V = np.linalg.lstsq(gram_u, np.dot(U, X))[0]
    V = np.maximum(V, 0.)

    # Update U with V held fixed: the same solve with the roles swapped,
    # (V V^T + lam I) U = V X^T, followed by the same projection.
    gram_v = np.dot(V, V.T) + ridge
    U = np.linalg.lstsq(gram_v, np.dot(V, X.T))[0]
    U = np.maximum(U, 0.)

### Analyze the results

In [None]:
# Summarize the factors: their shapes and largest entries.
# Each message is formatted into a single string so the call is valid, and
# prints the same text, on both Python 2 and Python 3.
print("V shape: {}".format(V.shape))
print("U shape: {}".format(U.shape))
print("V.max: {}".format(V.max()))
print("U.max: {}".format(U.max()))

#### Let's plot some image representations

In [None]:
# Show the weight vectors of every 40th image: after the transpose each row
# is one image and each column is the weight on one row of V.
fig, ax = plt.subplots(figsize=(15, 7))
sampled_weights = U[:, ::40].T
im = ax.matshow(sampled_weights)
ax.set_xlabel('Weights', size=20)
ax.set_ylabel('Image representations', size=20)
fig.colorbar(im, ax=ax)
plt.show()

#### Plot some of the base images

In [None]:
# Display V as a matrix: one row per base image, one column per pixel.
fig, ax = plt.subplots(figsize=(16, 7))
im = ax.matshow(V)
ax.set_xlabel('Normal representation', size=20)
ax.set_ylabel('Base images', size=20)
# Colorbar is disabled here; re-enable with: fig.colorbar(im, ax=ax)
plt.show()

#### Plot a few base images after reshaping

In [None]:
# Turn each row of V back into a 2-D image with `width` columns.
base_images = [x.reshape(-1, width) for x in V]

# Grid layout: 5 panels per row, with as many rows as needed (ceiling division).
# Floor division keeps the row count an integer on Python 3, where `/` would
# yield a float that plt.subplots rejects.
n_cols = 5
n_rows = (len(base_images) + n_cols - 1) // n_cols
# squeeze=False guarantees a 2-D array of axes even when there is a single row.
fig, ax = plt.subplots(n_rows, n_cols,
                       figsize=(12, n_rows * 3),
                       squeeze=False)
for idx, panel in enumerate(ax.ravel()):
    if idx >= len(base_images):
        # Fewer images than panels: hide the unused trailing panels.
        panel.axis('off')
        continue
    panel.imshow(base_images[idx], cmap='gray')
    panel.set_title(str(idx), size=16)

plt.suptitle('The base images', size=20)
# Leave the top 3% of the figure free for the suptitle.
fig.tight_layout(rect=[0, 0, 1, 0.97])
plt.show()