## PCA

In [145]:
using LinearAlgebra
using Flux.Data.MNIST
using Plots

┌ Info: Recompiling stale cache file /Users/hidehisa/.julia/compiled/v1.0/Plots/ld3vC.ji for Plots [91a5bcdd-55d7-5caf-9e0b-520d859cae80]
└ @ Base loading.jl:1190


In [46]:
"""
    PCA(X, m; center=false)

Project the columns of `X` onto the `m` leading principal directions.

# Arguments
- `X`: `d × n` matrix holding one sample per column.
- `m`: number of components to keep, `0 <= m <= d`.

# Keywords
- `center`: when `true`, subtract the per-row mean from `X` before the decomposition
  and the projection. The default `false` decomposes the raw second-moment matrix
  `X * X'` and projects `X` unchanged.

# Returns
An `m × n` matrix whose `i`-th row holds the coordinates along the eigenvector with the
`i`-th largest eigenvalue. The sign of each row is arbitrary, as eigenvectors are only
defined up to sign.

# Throws
- `ArgumentError` if `m` is outside `0:d`.
"""
function PCA(X, m; center=false)
    d, n = size(X)
    0 <= m <= d || throw(ArgumentError("m must satisfy 0 <= m <= $d, got $m"))

    data = center ? X .- sum(X; dims=2) ./ n : X

    # One matrix product replaces the per-column outer-product accumulation; adding to a
    # Float64 zero matrix keeps the accumulator at least Float64.
    C = Symmetric(zeros(d, d) + data * data')

    # For a Symmetric input `eigen` returns eigenvalues in ascending order, so the
    # leading directions are the last columns, taken here in descending order.
    eigs = eigen(C)
    T = eigs.vectors[:, d:-1:(d - m + 1)]'
    return T * data
end

PCA (generic function with 1 method)

In [51]:
# Load the MNIST test images and flatten each one into a column of X.
imgs = MNIST.images(:test)
# reduce(hcat, ...) concatenates the columns without splatting every image into varargs.
X = reduce(hcat, float.(vec.(imgs)))
# Project every image onto its two leading principal directions.
X_pca = PCA(X, 2)

2×10000 Array{Float64,2}:
 4.21304   5.35283   2.58519   8.85142  …  7.57268  5.95091   8.7352 
 0.94335  -0.531278  2.17549  -3.88506     2.69419  1.43468  -3.66821

## KMeans

In [143]:
"""
    converged(array, tol)

Return `true` when the total of all squared elements over the entries of `array` is
below `tol`.

`array` is a collection of numeric arrays; `KMeans` passes the per-cluster displacement
of the centers between two iterations.
"""
function converged(array, tol)
    total = 0.0
    for delta in array
        total += sum(abs2, delta)
    end
    return total < tol
end


"""
    calc_center(X)

Return the centroid (mean over columns) of the points stored as columns of `X`.

Works for any number of rows. When `X` has no columns the entries are `NaN`, since the
zero sum is divided by zero.
"""
function calc_center(X)
    n = size(X, 2)
    return vec(sum(X; dims=2)) ./ n
end


"""
    KMeans(X, c, tol=1e-5; maxiter=1000)

Cluster the columns of the `d × n` matrix `X` into `c` groups with Lloyd's algorithm.

# Arguments
- `X`: data matrix, one point per column, any number of rows.
- `c`: number of clusters, at least 1.
- `tol`: iteration stops once the summed squared movement of all centers is below `tol`.

# Keywords
- `maxiter`: upper bound on the number of assignment/update rounds, so that input that
  never satisfies `tol` (e.g. data containing `NaN`) cannot loop forever.

# Returns
`(centers, cpreds)`: `centers` is a length-`c` vector of center coordinates and `cpreds`
is a length-`n` `Float64` vector holding the cluster index (`1.0`…`c`) of each column.

Centers are initialised uniformly at random in `[0, 1)^d`. A cluster that receives no
points keeps its previous center.

# Throws
- `ArgumentError` if `c < 1`.
"""
function KMeans(X, c, tol=1e-5; maxiter=1000)
    c >= 1 || throw(ArgumentError("c must be at least 1, got $c"))
    d, n = size(X)
    centers = [rand(d) for _ in 1:c]
    cpreds = zeros(n)
    for _ in 1:maxiter
        # Assignment step: each point takes the label of its nearest center.
        for i in 1:n
            point = X[:, i]
            cpreds[i] = argmin([sum(abs2, point - center) for center in centers])
        end

        # Update step: move every non-empty cluster to the mean of its members.
        before = centers
        centers = map(1:c) do k
            idx = findall(cpreds .== k)
            isempty(idx) ? before[k] : calc_center(X[:, idx])
        end

        converged(centers - before, tol) && break
    end
    return centers, cpreds
end

KMeans (generic function with 2 methods)

In [146]:
# Cluster the PCA-projected images into 10 groups; returns the centers and per-column labels.
centers, cpreds = KMeans(X_pca, 10)

(Array{Float64,1}[[7.54484, -1.31615], [4.48019, 0.696349], [6.21759, 0.635575], [8.85981, -4.78443], [5.17758, -1.24386], [5.53064, 2.64259], [7.75908, 1.71639], [5.97506, -3.58853], [3.45779, 2.21392], [9.71095, -0.0979869]], [2.0, 5.0, 9.0, 4.0, 5.0, 9.0, 2.0, 2.0, 1.0, 7.0  …  6.0, 10.0, 1.0, 4.0, 6.0, 10.0, 1.0, 7.0, 3.0, 4.0])

## Plotting

In [149]:
# Scatter the PCA-projected points, one series per k-means cluster, and save as PNG.
p = plot(xlabel="x", ylabel="y", title="Compressed image of MNIST")
colors = [:green, :blue, :red, :yellow, :pink, :black, :gold, :silver, :brown, :purple]
for i in eachindex(colors)
    # Columns of X_pca assigned to cluster i.
    members = findall(label -> label == i, cpreds)
    xs = X_pca[1, members]
    ys = X_pca[2, members]
    scatter!(xs, ys, color=colors[i], label=string(i))
end
png("scatter")

In [150]:
# Load the labels of the MNIST test split (the output below shows digit values 0-9).
labels = MNIST.labels(:test)

10000-element Array{Int64,1}:
 7
 2
 1
 0
 4
 1
 4
 9
 5
 9
 0
 6
 9
 ⋮
 5
 6
 7
 8
 9
 0
 1
 2
 3
 4
 5
 6

In [151]:
# Scatter the PCA-projected points, one series per true digit label, and save as PNG.
p = plot(xlabel="x", ylabel="y", title="Compressed image of MNIST with true label")
colors = [:green, :blue, :red, :yellow, :pink, :black, :gold, :silver, :brown, :purple]
# The labels are the digits 0-9, not 1-10: pair each color with its digit so that
# digit 0 is plotted and no series is left empty.
for (digit, c) in zip(0:9, colors)
    idx = findall(labels .== digit)
    scatter!(X_pca[1, idx], X_pca[2, idx], color=c, label=string(digit))
end
png("scatter_true")