# Sparse Coding

The sparse coding model follows:
$$I = f(UR) + \lambda\Vert R\Vert_1 $$
where $U$ is overcomplete and the $L_1$ norm constrains the sparsity of $R$

In [1]:
using MAT
using LinearAlgebra
using Statistics
using PyPlot

## load image patches 

In [11]:
function read_images(filepath::String)
    vars = matread(filepath)
    return vars["IMAGES"]
end

read_images (generic function with 1 method)

In [12]:
function crop_patches(data::Array{<:Float64}, patch_size::Int64, n::Int64, border=4)
    num_imgs = size(data)[3] 
    X = zeros(n * num_imgs, patch_size, patch_size)
    img_size = size(data)[1]  
    for i = 1:num_imgs
        for j = 1:n
            x, y = rand(border:(img_size - border - patch_size)), rand(border:(img_size - border - patch_size))
            X[(i-1) * n + j, :, :] = data[x:(x+patch_size-1), y:(y+patch_size-1), i]
        end
    end
    return X
end

crop_patches (generic function with 2 methods)

In [None]:
data = read_images("../data/IMAGES.mat");

In [None]:
X = crop_patches(data, 16, 2000);

## ISTA

In [13]:
function ReLU(x)
    return max(x, 0);
end

ReLU (generic function with 1 method)

In [14]:
function shrinkage(x, λ::Float64)
    return ReLU(x - λ) - ReLU(-1 .* x - λ);
end

shrinkage (generic function with 1 method)

In [17]:
function ista(I::Array{<:Float64}, K::Int64, U::Array{<:Float64}, E::Int64, λ::Float64, η::Float64)
    batch_size = size(I)[1];
    R = zeros((batch_size, K));
    for e = 1:E
        # calculate gradient
        Î = R * U;
        ∇ = -2.0 * ((I - Î) * U');
        R = R - η * ∇;
        # shrinkage
        R = shrinkage(R, λ);
    end
    return R;
end

ista (generic function with 2 methods)

## sparse coding model

In [18]:
function normalize_mat!(mat::Array{<:Float64})
    mat = mat ./ sqrt.(sum(mat .^ 2, dims=2));
    return mat;
end

normalize_mat! (generic function with 1 method)

In [None]:
# params
N = 2000;
K = 100;
M = 100;
E = 10000;
Eᵣ = 300;
λ = 5e-3;
η = 1e-3;
ηᵣ = 1e-2;
# synaptic weights
U = rand(K, M);
normalize_mat!(U);

In [123]:
for e = 1:E
    # select batch
    I = X[rand(1:size(X)[1]), N];
    # fit R
    R = ista(I, K, U, Eᵣ, λ, ηᵣ);
    # reconstruction
    Î = R * U;
    # TODO: Finish this
    break
end

BoundsError: BoundsError: attempt to access 20000×16×16 Array{Float64,3} at index [11783, 2000]