In [15]:
#******************* Classify handwritten digits using low-rank approximation ************************
# This is a course task from the University of Michigan course EECS 551; some functions, such as imshow3, were provided by Prof. Fessler.
# This file is intended only for learning and discussion.
#Haowei Xiang
#Date:10/26

using Plots
using Interact
#pyplot()
plotly();

# Read the MNIST data files for the digits 0 to 9 (files data0 ... data9 in the working directory).
# Download them from the web if needed, e.g. http://cis.jhu.edu/~sachin/digit/data0

nx = 28 # original image size
ny = 28
nrep = 1000 # number of images read from each digit file

# digits 0-9
nnumbers=10

# Initialize and read the data.
# Slab i holds the file "data$(i%10)", so digits 1-9 are in slabs 1-9 and digit 0 is in slab 10.
data=zeros(nx,ny,nrep,nnumbers);
for i=1:nnumbers
    # the do-block form closes the file even when the read throws
    data[:,:,:,i] = open("data$(i%10)", "r") do fp
        read(fp, UInt8, (nx,ny,nrep))
    end
end

display(size(data[:,:,:,1]))

[1m[36mINFO: [39m[22m[36mPrecompiling module Interact.
[39m

(28, 28, 1000)

In [2]:
# Display a mosaic of multiple images.
#
# `x` must be a 3-D array whose third dimension indexes the images (the code
# permutes its dimensions as [1, 3, 2]). The stack is rearranged into a
# (size(x,1) * number_of_images) × size(x,2) matrix and drawn as a single
# grayscale heatmap, which is returned.
# The tile size is taken from `x` itself instead of the globals `nx`/`ny`, so the
# helper does not depend on notebook state and accepts other image sizes too.
imshow3 = (x) -> begin
    tilex, tiley = size(x, 1), size(x, 2)
    # bring the image index next to the first axis so that reshape chains the tiles
    tmp = permutedims(x, [1, 3, 2])
    tmp = reshape(tmp, :, tiley)
    heatmap(1:size(tmp,1), 1:tiley, tmp,
        xtick=[1,tilex], ytick=[1,tiley], yflip=true,
        color=:grays, transpose=true, aspect_ratio=1)
end

(::#1) (generic function with 1 method)

In [21]:
# Peek at a few sample images: one each from slabs 3, 4 and 7
imshow3(cat(3,
    data[:,:,44,3],
    data[:,:,3,4],
    data[:,:,3,7]))

In [27]:
# Use some data for training and the rest for testing
ntrain = 100
ntest = nrep - ntrain

# Separate train and test data: the first ntrain images of every digit are used for training
train=data[:,:,1:ntrain,:];
test=data[:,:,(ntrain+1):end,:];

# Reshape the 4-D nx*ny*n*nnumbers image arrays to 3-D (nx*ny)*n*nnumbers, one column per image
train_2D=reshape(train,nx*ny,ntrain,nnumbers);
test_2D=reshape(test,nx*ny,ntest,nnumbers);

# Thin SVD (U, s, V) of every digit's training matrix.
# A thin SVD of an (nx*ny)×ntrain matrix yields min(nx*ny, ntrain) singular triplets,
# so all three containers are sized from that single count.
nsv=min(ntrain,nx*ny);
u=zeros(nx*ny,nsv,nnumbers);
s=zeros(nsv,nnumbers);
v=zeros(ntrain,nsv,nnumbers);
for i=1:nnumbers
    u[:,:,i],s[:,i],v[:,:,i]=svd(train_2D[:,:,i])
end

# Plot the singular values; the widget selects which digit's curve is shown
klist=collect(1:nnumbers);
@manipulate for i = klist
    plot(1:nsv, s[:,i],
        marker =:circle,
        label = "singular value $(i%10)",
        ylabel = "singular value",
        xlabel = "k",
    )
end

In [37]:
# Mean image of every class, just to get a sense of things:
# average the training stack over its image axis (dim 3), then drop that axis
mean_train=squeeze(mean(train,3),3);

imshow3(mean_train)

In [39]:
# Estimate subspaces for each digit class from training data
# set rank to be 5
k_lowrank=5

# View the first k_lowrank left singular vectors of every digit as nx×ny images:
# basis_3D[:,:,i,j] is column i of u[:,:,j] reshaped to nx×ny
basis_3D=reshape(u[:,1:k_lowrank,1:nnumbers],nx,ny,k_lowrank,nnumbers)

# Show the leading basis image of every digit (slabs 1..nnumbers rather than a hard-coded 1:10)
imshow3(basis_3D[:,:,1,1:nnumbers])

In [67]:
# Nearest-subspace classifier.
#
# Every test vector x is compared with every subspace basis Qk = Q0[:,:,k] through
# the residual norm ||x - Qk*(Qk'*x)|| and is assigned the index k of the smallest
# residual.
#
# Arguments
#   test_2D   : 3-D array; test_2D[:,j,i] is test vector j of true class slab i
#   Q0        : 3-D array; Q0[:,:,k] is the basis matrix of class k
#   k_lowrank : not used by the computation (every column of Q0 is used); kept so
#               that existing calls keep working
#   ntest     : number of test vectors classified per slab (j runs over 1:ntest)
#
# Returns a Float64 matrix of size ntest × size(test_2D,3) whose entry [j,i] is the
# predicted class index (stored as a float) for test vector j of slab i.
#
# The number of slabs and the number of classes are read from the arguments rather
# than from the global `nnumbers`, and only functions available in Base across Julia
# versions are used (`vecnorm` was removed in Julia 1.0).
function classifier_number(test_2D,Q0,k_lowrank,ntest)
    nslabs = size(test_2D, 3)
    nclasses = size(Q0, 3)

    # slice every basis once instead of twice per (slab, vector, class) triple
    bases = [Q0[:, :, k] for k in 1:nclasses]

    test_results = zeros(ntest, nslabs)
    for i in 1:nslabs
        for j in 1:ntest
            x = test_2D[:, j, i]
            bestdist = Inf
            for k in 1:nclasses
                Qk = bases[k]
                residual = x - Qk * (Qk' * x)
                dist = sqrt(sum(abs2, residual))
                # `<=` keeps the last class among exact ties, which is what the
                # previous equality scan over all classes produced
                if dist <= bestdist
                    bestdist = dist
                    test_results[j, i] = k
                end
            end
        end
    end
    return test_results
end

classifier_number (generic function with 1 method)

In [77]:
# Core part: classify all the test data with the subspace estimates
# and count the number of misclassified digits

# Basis is Q=Uk: the first k_lowrank left singular vectors of every digit
k_lowrank=3;
Q0=u[:,1:k_lowrank,1:nnumbers];

# test_results[j,i] is the class predicted for test image j of true slab i
test_results=classifier_number(test_2D,Q0,k_lowrank,ntest);

# correct_test[j,i] is 1 where the prediction equals the true slab i, 0 otherwise
correct_test=zeros(ntest,nnumbers);

# bad[:,i] lists the positions j of the misclassified images of slab i, packed at the
# top of the column and padded with zeros. The positions refer to the test split
# (test / test_2D), not to the full data array.
bad=zeros(ntest,nnumbers);

for i=1:nnumbers
    nbad=0;
    for j=1:ntest
        if test_results[j,i]==i
            correct_test[j,i]=1
        else
            nbad=nbad+1;
            bad[nbad,i]=j;
        end
    end
end

# fraction of correctly classified test images per digit
for i=1:nnumbers
    println("correctness for $(i%10) is : $(sum(correct_test[:,i]) / ntest)")
end


correctness for 1 is : 0.9855555555555555
correctness for 2 is : 0.8866666666666667
correctness for 3 is : 0.8166666666666667
correctness for 4 is : 0.7366666666666667
correctness for 5 is : 0.7933333333333333
correctness for 6 is : 0.9088888888888889
correctness for 7 is : 0.8633333333333333
correctness for 8 is : 0.7711111111111111
correctness for 9 is : 0.8666666666666667
correctness for 0 is : 0.9722222222222222


In [116]:
# Integer copy of the misclassified positions, so they can be used as array indices
badint=map(Int,bad)

900×10 Array{Int64,2}:
 166    7   2   5   2   10   17   1   8   13
 245   14   3   6   7   48   20   2  21   34
 353   26   8   9  11   57   28   3  28   43
 401   44  12  11  12   61   30   7  37   83
 483   53  15  13  15   74   35   8  40  122
 554   54  16  14  21   75   65  11  46  162
 559   90  17  17  24  105   69  26  49  166
 604  108  18  19  25  137   87  42  50  217
 673  114  24  29  26  138   93  53  56  293
 704  132  27  34  56  153  109  59  57  389
 755  142  30  36  60  168  126  64  67  418
 891  161  36  39  64  180  133  69  73  473
 899  168  43  42  65  192  137  72  90  500
   ⋮                     ⋮                  
   0    0   0   0   0    0    0   0   0    0
   0    0   0   0   0    0    0   0   0    0
   0    0   0   0   0    0    0   0   0    0
   0    0   0   0   0    0    0   0   0    0
   0    0   0   0   0    0    0   0   0    0
   0    0   0   0   0    0    0   0   0    0
   0    0   0   0   0    0    0   0   0    0
   0    0   0   0   0    0    0 

In [126]:
# First ten misclassified images of slab 4. badint holds positions within the test
# split, so index `test` (indexing `data` would show images shifted by ntrain)
imshow3(test[:,:,badint[1:10,4],4])


In [127]:
# First ten misclassified images of slab 7. badint holds positions within the test
# split, so index `test` (indexing `data` would show images shifted by ntrain)
imshow3(test[:,:,badint[1:10,7],7])

In [130]:
# Misclassified images 90 to 100 of slab 8. badint holds positions within the test
# split, so index `test` (indexing `data` would show images shifted by ntrain)
imshow3(test[:,:,badint[90:100,8],8])