# Nonnegative Matrix Factorization

In [nonnegative matrix factorization (NMF)](https://www.nature.com/articles/44565), a data matrix $\mathbf{X} \in \mathbb{R}^{m \times n}$ with nonnegative entries $x_{ij}$ is approximated by the product of two rank-$r$ factors $\mathbf{V} \in \mathbb{R}^{m \times r}$ and $\mathbf{W} \in \mathbb{R}^{r \times n}$ with nonnegative entries $v_{ik}$ and $w_{kj}$. The goal is to minimize the squared Frobenius norm of the residual,
$$
	L(\mathbf{V}, \mathbf{W}) = \|\mathbf{X} - \mathbf{V} \mathbf{W}\|_{\text{F}}^2 = \sum_i \sum_j \left(x_{ij} - \sum_k v_{ik} w_{kj} \right)^2, \quad v_{ik} \ge 0, \; w_{kj} \ge 0.
$$
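As a quick sanity check on the notation, the sketch below evaluates the objective in both its matrix form and its elementwise form on a small made-up example. The helper name `frobenius_loss` and the toy matrices are assumptions for illustration only; they are not part of the notebook's implementation.

```julia
# Hypothetical helper (illustration only): the objective L(V, W).
frobenius_loss(X, V, W) = sum(abs2, X - V * W)

# Made-up example with m = 3, n = 4, r = 2; all entries are nonnegative.
X = [1.0 2.0 3.0 4.0; 2.0 4.0 6.0 8.0; 1.0 1.0 0.0 1.0]
V = [1.0 0.0; 2.0 0.0; 0.0 1.0]
W = [1.0 2.0 3.0 4.0; 0.0 1.0 0.0 1.0]

# Matrix form: squared Frobenius norm of the residual X - V * W.
loss_matrix = frobenius_loss(X, V, W)

# Elementwise form: the double sum over i and j from the formula.
loss_sum = sum((X[i, j] - sum(V[i, k] * W[k, j] for k in 1:2))^2
               for i in 1:3, j in 1:4)
```

Here $\mathbf{V}\mathbf{W}$ reproduces $\mathbf{X}$ except in entry $(3, 1)$, so both forms evaluate to 1.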

Below we implement a majorization-minimization (MM) algorithm that alternates the multiplicative updates
$$
	v_{ik}^{(t+1)} = v_{ik}^{(t)} \frac{\sum_j x_{ij} w_{kj}^{(t)}}{\sum_j b_{ij}^{(t)} w_{kj}^{(t)}}, \quad \text{where } b_{ij}^{(t)} = \sum_k v_{ik}^{(t)} w_{kj}^{(t)},
$$

$$
w_{kj}^{(t+1)} = w_{kj}^{(t)} \frac{\sum_i x_{ij} v_{ik}^{(t+1)}}{\sum_i b_{ij}^{(t+1/2)} v_{ik}^{(t+1)}}, \quad \text{where } b_{ij}^{(t+1/2)} = \sum_k v_{ik}^{(t+1)} w_{kj}^{(t)}
$$
which drive the objective $L^{(t)} = L(\mathbf{V}^{(t)}, \mathbf{W}^{(t)})$ downhill. The superscript $t$ indexes the iteration.
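Before turning to the optimized implementation, it may help to see the two updates written naively. The following is a minimal sketch that allocates new arrays at every step; the names `mm_step` and `mm_losses` and the random test data are assumptions made for illustration, not the notebook's code.

```julia
using Random

# Hypothetical, allocating reference version of one MM sweep.
function mm_step(X, V, W)
    B = V * W                                    # b_ij at iteration t
    Vnew = V .* (X * W') ./ (B * W')             # update V with W held fixed
    Bhalf = Vnew * W                             # b_ij at iteration t + 1/2
    Wnew = W .* (Vnew' * X) ./ (Vnew' * Bhalf)   # update W with the new V
    return Vnew, Wnew
end

# Run `niter` sweeps and record the objective before each one.
function mm_losses(X, V, W, niter)
    losses = Float64[]
    for _ in 1:niter
        push!(losses, sum(abs2, X - V * W))
        V, W = mm_step(X, V, W)
    end
    return V, W, losses
end

Random.seed!(1)
X = rand(6, 5)
V, W, losses = mm_losses(X, rand(6, 2), rand(2, 5), 20)
```

Because each update only multiplies a nonnegative entry by a ratio of nonnegative sums, the factors stay nonnegative, and the recorded losses should never increase from one sweep to the next.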

## Implement algorithm

Arguments include the data $\mathbf{X}$ (each row is a vectorized image), the rank $r$, the maximum number of iterations, a convergence tolerance, and optional starting points for $\mathbf{V}$ and $\mathbf{W}$. The key to an efficient implementation is that $r \ll \min(m, n)$: computing the denominator as $(\mathbf{VW})\mathbf{W}^T$ costs about $4mnr$ flops, whereas $\mathbf{V}(\mathbf{WW}^T)$ costs about $2(m + n)r^2$ flops, so the latter ordering is far cheaper.
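To put numbers on the two flop counts, the sketch below evaluates them at the dimensions of the face data used later in this notebook ($m = 2429$, $n = 361$) with $r = 20$, and checks that both orderings give the same matrix up to rounding. The random `V` and `W` are placeholders chosen only for this check.

```julia
# Dimensions of the face data used later in the notebook, with rank r = 20.
m, n, r = 2429, 361, 20

flops_naive = 4 * m * n * r        # (V * W) * W': two m-by-n-by-r products
flops_fast  = 2 * (m + n) * r^2    # V * (W * W'): forms the r-by-r matrix first
ratio = flops_naive / flops_fast

# Placeholder factors: the two orderings agree up to floating-point rounding.
V = rand(m, r)
W = rand(r, n)
A = (V * W) * W'
B = V * (W * W')
same = isapprox(A, B)
```

At these sizes the naive ordering needs roughly 31 times as many flops, which is why the implementation forms the small $r \times r$ matrices $\mathbf{WW}^T$ and $\mathbf{V}^T\mathbf{V}$ first.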

In [1]:
using LinearAlgebra  # provides BLAS.gemm! and norm

function nmf(
    X::Matrix{T},
    r::Int;
    maxiter::Int = 1000,
    tol::Float64 = 1e-4,
    V::Matrix{T} = rand(T, size(X, 1), r),
    W::Matrix{T} = rand(T, r, size(X, 2))
    ) where T <: AbstractFloat
    # Preallocate work arrays with element type T so every BLAS call sees
    # matching types (Float64 storage would fail for Float32 input)
    storage_mr = Matrix{T}(undef, size(X, 1), r)
    storage_rr = Matrix{T}(undef, r, r)
    storage_rn = Matrix{T}(undef, r, size(X, 2))
    # L holds the residual X - VW; the loss is its squared Frobenius norm
    L = copy(X)
    BLAS.gemm!('N', 'N', -one(T), V, W, one(T), L)
    loss = abs2(norm(L))
    # Iterative steps; note that V and W are overwritten in place
    for i in 1:maxiter
        # Update V .= V .* XWᵗ ./ V(WWᵗ)
        BLAS.gemm!('N', 'T', one(T), W, W, zero(T), storage_rr)           # WWᵗ
        BLAS.gemm!('N', 'N', one(T), V, storage_rr, zero(T), storage_mr)  # V(WWᵗ)
        V ./= storage_mr
        BLAS.gemm!('N', 'T', one(T), X, W, zero(T), storage_mr)           # XWᵗ
        V .*= storage_mr
        # Update W .= W .* VᵗX ./ (VᵗV)W
        BLAS.gemm!('T', 'N', one(T), V, V, zero(T), storage_rr)           # VᵗV
        BLAS.gemm!('N', 'N', one(T), storage_rr, W, zero(T), storage_rn)  # (VᵗV)W
        W ./= storage_rn
        BLAS.gemm!('T', 'N', one(T), V, X, zero(T), storage_rn)           # VᵗX
        W .*= storage_rn
        # Compare objective: recompute the residual X - VW in L
        copyto!(L, X)
        BLAS.gemm!('N', 'N', -one(T), V, W, one(T), L)
        newloss = abs2(norm(L))
        # Stop when the relative change in the objective falls below tol;
        # on this exit, `loss` still holds the previous iteration's value
        if abs(newloss - loss) / (abs(loss) + 1) < tol
            break
        end
        loss = newloss
    end
    return V, W, loss
end

nmf (generic function with 1 method)

## Download data
The data from the [MIT Center for Biological and Computational Learning (CBCL)](http://cbcl.mit.edu) contain $m = 2{,}429$ gray-scale face images with $n = 19 \times 19 = 361$ pixels per face. Each image (a row of $\mathbf{X}$) has been rescaled so that its pixel mean and standard deviation are both 0.25.
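The downloaded file is already rescaled, but a linear rescaling of this kind could be sketched as follows. The helper `rescale_image` and the toy pixel vector are hypothetical; the exact preprocessing applied to the CBCL file (for instance any clipping) is not specified here and is not reproduced.

```julia
using Statistics

# Hypothetical helper: shift and scale one vectorized image so that its
# mean and standard deviation both equal `target`.
function rescale_image(x::AbstractVector; target = 0.25)
    z = (x .- mean(x)) ./ std(x)    # standardize to mean 0, sd 1
    return target .* z .+ target    # move to mean `target`, sd `target`
end

# Made-up pixel intensities for illustration.
x = [0.0, 10.0, 20.0, 30.0, 40.0]
y = rescale_image(x)
```

Pixels more than one standard deviation below the image mean come out negative under this map, so a real pipeline would need an extra step (such as clipping) to keep the data nonnegative, as NMF requires.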

In [2]:
using DelimitedFiles
X = readdlm(download("http://Hua-Zhou.github.io/teaching/biostatm280-2018spring/hw/hw2/nnmf-2429-by-361-face.txt"), ' ', Float64)
V0 = readdlm(download("http://Hua-Zhou.github.io/teaching/biostatm280-2018spring/hw/hw2/V0.txt"), ' ', Float64)
W0 = readdlm(download("http://Hua-Zhou.github.io/teaching/biostatm280-2018spring/hw/hw2/W0.txt"), ' ', Float64);


## Display data

In [3]:
using ImageView
imshow(reshape(X[rand(1:size(X, 1)), :], 19, 19))  # show a randomly chosen face; PyPlot is an alternative viewer

Dict{String, Any} with 4 entries:
  "gui"         => Dict{String, Any}("window"=>GtkWindowLeaf(name="", parent, w…
  "roi"         => Dict{String, Any}("redraw"=>37: "map(clim-mapped image, inpu…
  "annotations" => 3: "input-2" = Dict{UInt64, Any}() Dict{UInt64, Any} 
  "clim"        => 2: "CLim" = CLim{Float64}(0.0, 1.0) CLim{Float64} 

## Check speed and memory 

In [4]:
using LinearAlgebra
for r in 10:10:50
    @time nmf(X, r; V = V0[:, 1:r], W = W0[1:r, :])
end

  0.770718 seconds (953.68 k allocations: 58.833 MiB, 1.62% gc time, 46.47% compilation time)
  0.831669 seconds (21 allocations: 7.545 MiB)
  1.249642 seconds (21 allocations: 7.975 MiB, 0.98% gc time)
  1.639336 seconds (21 allocations: 8.406 MiB)
  2.173468 seconds (22 allocations: 8.839 MiB)


In [5]:
using Random
Random.seed!(1234)

r = 20
V, W, loss = nmf(X, r)

([0.025731260290675448 0.006434289755739342 … 0.10033271526215089 0.0233267333067549; 0.043500976713445834 0.011915389119817476 … 0.06122288374952517 0.008483063914880531; … ; 0.013144182901897963 0.010685465835828955 … 0.06019095194260951 0.058262261360138456; 5.548850952514377e-10 0.051961163757573356 … 0.004577258754319713 0.034400141294116254], [0.011673074186925739 0.04348766832906184 … 1.0262843459962898e-19 4.178669155575702e-53; 7.7677862319555055e-28 1.8635827754325187e-18 … 5.5469801327076915 4.072040782452289; … ; 0.00653541868937037 3.155625772304931e-15 … 0.021539223908874805 0.00381993217859815; 1.5160915791968106e-9 1.998536650017957e-16 … 0.06258574722950752 1.1274049262483066e-17], 8352.124499186913)

In [6]:
V, W, loss = nmf(X, r; V = ones(size(X, 1), r), W = ones(r, size(X, 2)))

([0.013687348350579975 0.013687348350579975 … 0.013687348350579975 0.013687348350579975; 0.013603286404271707 0.013603286404271707 … 0.013603286404271707 0.013603286404271707; … ; 0.014257981139452211 0.014257981139452211 … 0.014257981139452211 0.014257981139452211; 0.01401498047068391 0.01401498047068391 … 0.01401498047068391 0.01401498047068391], [0.36994329971813567 0.4606433644483843 … 0.5581000847461051 0.4255419042717461; 0.36994329971813567 0.4606433644483843 … 0.5581000847461051 0.4255419042717461; … ; 0.36994329971813567 0.4606433644483843 … 0.5581000847461051 0.4255419042717461; 0.36994329971813567 0.4606433644483843 … 0.5581000847461051 0.4255419042717461], 25297.357820306384)

In [7]:
imshow(reshape(W[rand(1:r), :], 19, 19));

## Discussion 

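The NMF objective is not jointly convex in $(\mathbf{V}, \mathbf{W})$, so the MM iterates converge to different solutions from different starting points, and the fitted factors also change with the rank $r$. In the runs above with $r = 20$, the random start reaches a loss of about 8352, while the all-ones start stalls at about 25297: every column of $\mathbf{V}$ and every row of $\mathbf{W}$ starts out identical, the updates keep them identical, and the fit is effectively rank 1. This is in contrast to principal component analysis (PCA), where the leading components are determined by the singular value decomposition (up to sign, when the singular values are distinct) and do not change when more components are requested.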
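The effect of a symmetric starting point can be reproduced on a small problem. The sketch below uses a naive, allocating version of the updates on made-up random data; `mm_step` and `iterate_mm` are illustrative names, not the notebook's `nmf` function.

```julia
using Random

# Hypothetical allocating version of the V-update followed by the W-update.
function mm_step(X, V, W)
    Vnew = V .* (X * W') ./ (V * W * W')
    Wnew = W .* (Vnew' * X) ./ (Vnew' * Vnew * W)
    return Vnew, Wnew
end

# Apply `niter` MM sweeps from the given starting point.
function iterate_mm(X, V, W, niter)
    for _ in 1:niter
        V, W = mm_step(X, V, W)
    end
    return V, W
end

Random.seed!(2)
X = rand(8, 6)

# All-ones start with r = 3: the columns of V and rows of W never separate.
V1, W1 = iterate_mm(X, ones(8, 3), ones(3, 6), 50)
columns_identical = all(V1[:, k] ≈ V1[:, 1] for k in 2:3)
rows_identical = all(W1[k, :] ≈ W1[1, :] for k in 2:3)
```

Since identical columns and rows receive identical updates, the product $\mathbf{V}\mathbf{W}$ stays a rank-1 matrix no matter how many iterations are run, which is one reason to prefer a random start.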

In [8]:
versioninfo()

Julia Version 1.6.0
Commit f9720dc2eb (2021-03-24 12:55 UTC)
Platform Info:
  OS: macOS (x86_64-apple-darwin19.6.0)
  CPU: Intel(R) Core(TM) i9-9880H CPU @ 2.30GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.1 (ORCJIT, skylake)
