Julia implementation of the paper _"Online Dictionary Learning for Sparse Coding"_

In [1]:
using Pkg
Pkg.add("PyCall") 
Pkg.add("DataStructures")
Pkg.add("SparseArrays")
Pkg.add("ProgressBars")
Pkg.add("DelimitedFiles")
Pkg.add("Plots")
Pkg.add("Images")


[32m[1m  Updating[22m[39m registry at `~/.julia/registries/General`
[32m[1m  Updating[22m[39m git-repo `https://github.com/JuliaRegistries/General.git`
[32m[1m  Updating[22m[39m `~/.julia/environments/v1.2/Project.toml`
[90m [no changes][39m
[32m[1m  Updating[22m[39m `~/.julia/environments/v1.2/Manifest.toml`
[90m [no changes][39m
[32m[1m Resolving[22m[39m package versions...
[32m[1m  Updating[22m[39m `~/.julia/environments/v1.2/Project.toml`
[90m [no changes][39m
[32m[1m  Updating[22m[39m `~/.julia/environments/v1.2/Manifest.toml`
[90m [no changes][39m
[32m[1m Resolving[22m[39m package versions...
[32m[1m  Updating[22m[39m `~/.julia/environments/v1.2/Project.toml`
[90m [no changes][39m
[32m[1m  Updating[22m[39m `~/.julia/environments/v1.2/Manifest.toml`
[90m [no changes][39m
[32m[1m Resolving[22m[39m package versions...
[32m[1m  Updating[22m[39m `~/.julia/environments/v1.2/Project.toml`
[90m [no changes][39m
[32m[1m  Upda

In [2]:
using DataStructures
using SparseArrays
using LinearAlgebra
using ProgressBars
using PyCall
using Plots

In [3]:
# importing LARS from scikit-learn

using PyCall
lars = pyimport("sklearn.linear_model");
Lars = lars.Lars();

In [4]:
function init_dictionary(n::Int, K::Int)
    """
    Initialize the dictionary.
    Args: 
        n: dimension of input signal
        k = number of atoms in the dictionary
    """
    # D must be a full-rank matrix
    D = rand(n, K)
    while rank(D) != min(n, K)
        D = rand(n, K)
    end

    @inbounds for k in 1:K
        D[:, k] ./= norm(@view(D[:, k]))
    end
    return D
end

init_dictionary (generic function with 1 method)

In [5]:
function generate_random_vec(number_of_samples)
    return randn(number_of_samples,1)
end

generate_random_vec (generic function with 1 method)

In [6]:
function dictionary_update(D,A,B,threshold)
    """
    This is the so called algorithm2 in the paper.
    """
    # print("Entering Dictionary Update")
    D_old = ones(size(D))
    iter = 0
    while iter < 100
        for j=1:size(A)[2]
          u = (1/A[j,j])*(B[:,j]-D*A[:,j]) + D[:,j]
          D[:,j] = (1/max(norm(u,2),1))*u
        end
    iter+=1
    end
    return D
end

dictionary_update (generic function with 1 method)

In [7]:
function Algorithm1(data,number_iterations,n_atoms)
    """
    Implements the first algorithm.
    
    Args:
    
    initial_dict: initial dictionary
    number_iterations: number of iterations
    """
    dim_input = size(data)[1]
    threshold = 10
    D = init_dictionary(dim_input,n_atoms)
    A = zeros((n_atoms,n_atoms))
    B = zeros((dim_input,n_atoms))
    alpha = ones((n_atoms,1))
    for t in ProgressBar(1:number_iterations)
        x = generate_random_vec(dim_input)
        Lars.fit(reshape(D*alpha,(dim_input,1)),x)
        alpha = Lars.predict(reshape(alpha,(n_atoms,1)))
        A+= reshape(alpha,(n_atoms,1))*transpose(alpha)
        B+= reshape(x,(dim_input,1))*transpose(alpha)
        #print("A: ",A)
        #print("B: ",B)
        D = dictionary_update(D,A,B,threshold)
    end
    return D
end

Algorithm1 (generic function with 1 method)

In [8]:
# The implementation is referencing the wikipedia page
# https://en.wikipedia.org/wiki/Matching_pursuit#The_algorithm

const default_max_iter = 20
const default_tolerance = 1e-6


function SparseArrays.sparsevec(d::DefaultDict, m::Int)
    SparseArrays.sparsevec(collect(keys(d)), collect(values(d)), m)
end


function matching_pursuit_(data::AbstractVector, dictionary::AbstractMatrix,
                           max_iter::Int, tolerance::Float64)
    n_atoms = size(dictionary, 2)

    residual = copy(data)

    xdict = DefaultDict{Int, Float64}(0.)
    for i in 1:max_iter
        if norm(residual) < tolerance
            return sparsevec(xdict, n_atoms)
        end

        # find an atom with maximum inner product
        products = dictionary' * residual
        _, maxindex = findmax(abs.(products))
        maxval = products[maxindex]
        atom = dictionary[:, maxindex]

        # c is the length of the projection of data onto atom
        a = maxval / sum(abs2, atom)  # equivalent to maxval / norm(atom)^2
        residual -= atom * a

        xdict[maxindex] += a
    end
    return sparsevec(xdict, n_atoms)
end


"""
    matching_pursuit(data::Vector, dictionary::AbstractMatrix;
                     max_iter::Int = $default_max_iter,
                     tolerance::Float64 = $default_tolerance)
Find ``x`` such that ``Dx = y`` or ``Dx ≈ y`` where y is `data` and D is `dictionary`.
```
# Arguments
* `max_iter`: Hard limit of iterations
* `tolerance`: Exit when the norm of the residual < tolerance
```
"""
function matching_pursuit(data::AbstractVector, dictionary::AbstractMatrix;
                          max_iter::Int = default_max_iter,
                          tolerance = default_tolerance)

    if tolerance <= 0
        throw(ArgumentError("`tolerance` must be > 0"))
    end

    if max_iter <= 0
        throw(ArgumentError("`max_iter` must be > 0"))
    end

    if size(data, 1) != size(dictionary, 1)
        throw(ArgumentError(
            "Dimensions must match: `size(data, 1)` and `size(dictionary, 1)`."
        ))
    end

    matching_pursuit_(data, dictionary, max_iter, tolerance)
end


"""
    matching_pursuit(data::AbstractMatrix, dictionary::AbstractMatrix;
                     max_iter::Int = $default_max_iter,
                     tolerance::Float64 = $default_tolerance)
Find ``X`` such that ``DX = Y`` or ``DX ≈ Y`` where Y is `data` and D is `dictionary`.
```
# Arguments
* `max_iter`: Hard limit of iterations
* `tolerance`: Exit when the norm of the residual < tolerance
```
"""
function matching_pursuit(data::AbstractMatrix, dictionary::AbstractMatrix;
                          max_iter::Int = default_max_iter,
                          tolerance::Float64 = default_tolerance)
    K = size(dictionary, 2)
    N = size(data, 2)

    X = spzeros(K, N)

    for i in 1:N
        X[:, i] = matching_pursuit(
            vec(data[:, i]),
            dictionary,
            max_iter = max_iter,
            tolerance = tolerance
        )
    end
    return X
end

matching_pursuit

In [9]:
using DelimitedFiles


data = readdlm("../CIFAR10_data.dlm");
labels = readdlm("../CIFAR10_labels.dlm");

In [10]:
@time D = Algorithm1(data[1:100,:],10,3000);

100.0%┣██████████████████████████████████████████████████████████████┫ 10/10 [02:54<00:00, 0.1 it/s]
175.233193 seconds (34.01 M allocations: 86.633 GiB, 3.51% gc time)


In [17]:
X = matching_pursuit(data[101:200,:],D,max_iter=300);
X

3000×1024 SparseMatrixCSC{Float64,Int64} with 181470 stored entries:
  [1361,    1]  =  -1464.67
  [1393,    1]  =  -136.203
  [1394,    1]  =  1.34941
  [1406,    1]  =  1.07464
  [1411,    1]  =  2.53076
  [1424,    1]  =  1.94482
  [1437,    1]  =  1.92397
  [1438,    1]  =  46.091
  [1447,    1]  =  2.83491
  [1467,    1]  =  2.72766
  [1468,    1]  =  -88.9091
  [1483,    1]  =  -173.137
  ⋮
  [2898, 1024]  =  -9.46408
  [2899, 1024]  =  5.17171
  [2921, 1024]  =  -14.2971
  [2928, 1024]  =  -6.5412
  [2934, 1024]  =  -3.16134
  [2939, 1024]  =  -0.77082
  [2948, 1024]  =  -5.35732
  [2959, 1024]  =  -0.61133
  [2963, 1024]  =  -34.5683
  [2969, 1024]  =  130.107
  [2972, 1024]  =  214.422
  [2978, 1024]  =  -163.678
  [2981, 1024]  =  1.05868

In [21]:
img = (D*X)[70,:];
img2 = data[170,:];

In [22]:
plt = pyimport("matplotlib.pyplot");
plt.imshow(reshape(img2,(32,32)),cmap="gray")
plt.figure()
plt.imshow(reshape(img,(32,32)),cmap="gray")
plt.show()

In [23]:
#writedlm("dlm_files/CIFAR10_dict.dlm",D)