In [1]:
import struct
from array import array
from collections import defaultdict
import sys
from random import sample, randint
import numpy as np
import scipy.spatial
import math
import operator
import matplotlib.pyplot as plt

from scipy.stats import multivariate_normal

In [2]:
def load(path_img, path_lbl):
    """Read an IDX label file and an IDX image file into Python containers.

    Args:
        path_img: Path to the image file (header magic number 2051).
        path_lbl: Path to the label file (header magic number 2049).

    Returns:
        A tuple ``(images, labels)``. ``images`` is a list holding one flat
        list of ``rows * cols`` unsigned-byte pixel values per image, where
        the image count and dimensions come from the image-file header.
        ``labels`` is an ``array('B')`` with every byte that follows the
        8-byte label header.

    Raises:
        ValueError: If either file starts with an unexpected magic number.
    """
    with open(path_lbl, 'rb') as lbl_file:
        magic, size = struct.unpack(">II", lbl_file.read(8))
        if magic != 2049:
            raise ValueError('Magic number mismatch, expected 2049,'
                             'got {}'.format(magic))
        labels = array("B", lbl_file.read())

    with open(path_img, 'rb') as img_file:
        # `size` is re-read here: the image header decides how many images
        # are sliced out below.
        magic, size, rows, cols = struct.unpack(">IIII", img_file.read(16))
        if magic != 2051:
            raise ValueError('Magic number mismatch, expected 2051,'
                             'got {}'.format(magic))
        image_data = array("B", img_file.read())

    # Cut the flat pixel buffer into one plain list per image.
    pixels_per_image = rows * cols
    images = [
        list(image_data[k * pixels_per_image:(k + 1) * pixels_per_image])
        for k in range(size)
    ]

    return images, labels

In [3]:
# Read both IDX image/label pairs: the "train" pair (split into training and
# validation subsets later) and the "t10k" pair used as the test set.
raw_train_images, raw_train_labels = load(
    "data/hw3/train-images-idx3-ubyte",
    "data/hw3/train-labels-idx1-ubyte",
)
test_images, test_labels = load(
    "data/hw3/t10k-images-idx3-ubyte",
    "data/hw3/t10k-labels-idx1-ubyte",
)

# Scaling to [0, 1] is left disabled, so pixels stay as raw byte values:
# raw_train_images = np.divide(np.array(raw_train_images), 255.0)
# test_images = np.divide(np.array(test_images), 255.0)

In [4]:
# Fixed number of samples used for fitting; every remaining labelled sample
# goes to the validation split.
train_size = 50000
total_size = len(raw_train_images)
validation_size = total_size - train_size

In [5]:
# Randomly choose `train_size` of the `total_size` labelled samples for
# training; the complement becomes the validation set.
# NOTE: the draw is unseeded, so the split (and every number derived from it)
# changes from run to run.
indexs = sample(range(total_size), train_size)
indexs.sort()

# Convert the Python containers to arrays once and reuse them for both splits
# (the list-of-lists -> ndarray conversion is the expensive step here).
all_images = np.array(raw_train_images)
all_labels = np.array(raw_train_labels)

train_data = all_images[indexs]
train_labels = all_labels[indexs]

# Set membership is O(1); testing against the list made this complement
# quadratic in the dataset size.
train_index_set = set(indexs)
validation_indexs = [x for x in range(total_size) if x not in train_index_set]
validation_data = all_images[validation_indexs]
validation_labels = all_labels[validation_indexs]

# Fancy indexing above copied the selected rows, so the full-size arrays can
# be released.
del all_images, all_labels

In [6]:
# Bucket the training images by label: train_arrays[d] collects every
# training image whose label is d.
train_arrays = [[] for digit in range(10)]
for label, image in zip(train_labels[:train_size], train_data[:train_size]):
    train_arrays[label].append(image)

In [7]:
# Per-class mean image: means[d] is the pixel-wise average over all training
# images bucketed under label d.
means = []
for digit in range(10):
    class_images = np.array(train_arrays[digit])  # one row per image
    means.append(class_images.mean(axis=0))

In [8]:
# Per-class pixel covariance with `c` added to every diagonal entry
# (presumably to keep the matrices well-conditioned/invertible - TODO confirm
# how the value 3000 was chosen).
c = 3000
ridge = c * np.eye(784)
covariances = [
    np.cov(np.array(train_arrays[digit]).T) + ridge
    for digit in range(10)
]

In [9]:
# Sanity check on the fitted parameters: one mean vector and one covariance
# matrix per class. The parenthesized single-argument form prints the same
# text on Python 2 and is valid on Python 3.
print(np.array(means).shape)
print(np.array(covariances).shape)

(10, 784)
(10, 784, 784)


In [10]:
# One frozen multivariate normal per class, built from that class's mean and
# regularized covariance; rv[d].logpdf(x) then scores an image under class d.
rv = [
    multivariate_normal(means[digit], covariances[digit])
    for digit in range(10)
]

In [11]:
def judge(x):
    """Classify `x` by the highest class log-density.

    Evaluates ``rv[k].logpdf(x)`` for the ten class models in the module-level
    list ``rv``.

    Returns:
        A tuple ``(label, margin)``: the index of the class with the largest
        log-density, and the gap between that log-density and the second
        largest one.
    """
    scores = [rv[label].logpdf(x) for label in range(10)]
    # Ascending stable sort on the score, so the winner is last and the
    # runner-up sits right before it.
    ranked = sorted(enumerate(scores), key=lambda pair: pair[1])
    best_label, best_score = ranked[-1]
    runner_up_score = ranked[-2][1]
    return best_label, best_score - runner_up_score

In [26]:
# Classify every test image, collecting the misclassified ones for later
# inspection. For each sample the top-two log-density margin is printed;
# margins of misclassified samples are indented by five spaces.
# (Iterate over a short slice of the test set instead for a quick smoke test.)
count = 0
misclassied_images = []
misclassied_labels = []
misclassied_prediction = []
for i in range(len(test_images)):
    prediction, diff = judge(test_images[i])
    if prediction != test_labels[i]:
        count += 1
        misclassied_images.append(test_images[i])
        misclassied_labels.append(test_labels[i])
        misclassied_prediction.append(prediction)
        # Same text as the Python 2 statement `print "    ", diff`.
        print("     " + str(diff))
    else:
        print(diff)
# Test error rate; as the last expression it is displayed as the cell output.
count * 1.0 / len(test_images)

58.2707134811
69.2877908381
93.6186926838
79.6819110475
14.2287554459
90.4904815678
25.5375880178
36.0879237025
52.3707357295
23.9222884854
103.78366299
69.1661110022
32.1921917158
89.5896563048
95.6846202232
12.0329638486
43.0767911589
89.1264444753
2.82019525983
21.0261517956
32.2737909249
52.5481092633
84.1532838826
32.7532197571
12.4936468979
107.692663734
32.2375835646
23.4646869556
62.7843698131
93.8029985754
32.2264267158
88.6149847833
37.5570846176
8.69252984329
45.1661136834
69.0707175689
52.3316250575
99.9216575695
23.8508584553
83.7474402023
95.9815355957
20.9369274327
19.4706718002
42.9079031554
23.7179487315
13.7767945732
72.265842295
64.8845767567
16.8343650267
49.6837715669
62.2049457922
68.8716363173
51.4529236504
18.3087240544
116.57453639
53.4542759715
59.6381138873
90.8510040824
40.7299397425
2.87114850535
93.5475332297
20.9827607315
28.6364089158
10.4427285984
89.5107525003
11.0236103522
42.171607792
34.1225791939
29.2858305643
112.161830915
122.770127227
105.050016

0.0433

In [30]:
def pr(x, y):
    """Return the softmax probability of entry `y` given log-scores `x`.

    Computes ``exp(x[y]) / sum_i exp(x[i])`` over every entry of `x`. All
    scores are shifted by ``max(x)`` before exponentiating, so widely
    separated scores underflow to 0.0 instead of raising ``OverflowError``
    from ``math.exp``.

    Args:
        x: Non-empty sequence of log-scores, one per class.
        y: Index into `x` of the class whose probability is wanted.

    Returns:
        A float in [0, 1]; the values over all `y` sum to 1 up to rounding.
    """
    peak = max(x)
    denominator = 0.0
    for score in x:
        # score - peak <= 0, so exp() can only underflow, never overflow.
        denominator += math.exp(score - peak)
    return math.exp(x[y] - peak) / denominator
    
    
def show(image, name):
    """Draw a flattened 28x28 image in grayscale and save it as a PNG.

    Args:
        image: Indexable sequence whose first 784 entries are the pixel
            values in row-major order.
        name: Value inserted into the output file name
            ``"hw3_f" + str(name) + ".png"``.
    """
    side = 28
    # Rebuild the 2-D pixel grid row by row from the flat sequence.
    grid = [
        [image[row * side + col] for col in range(side)]
        for row in range(side)
    ]
    scaled = np.divide(grid, 255.0)
    plt.figure(1)
    plt.imshow(scaled, cmap=plt.get_cmap('gray'))
    plt.savefig("hw3_f" + str(name) + ".png")
    
    
def showimg(num):
    """Print diagnostics for misclassified test sample `num` and save its image.

    Output order: the ten class log-densities from ``rv[k].logpdf``, the ten
    normalized class probabilities from ``pr``, the true label, then the
    predicted label. The image itself is rendered and saved via ``show``.

    Args:
        num: Index into the module-level ``misclassied_*`` lists.
    """
    image = misclassied_images[num]

    x = []
    for i in range(10):
        temp = rv[i].logpdf(image)
        # Single-argument print(...) prints the same text on Python 2 and is
        # valid on Python 3.
        print(temp)
        x.append(temp)

    for i in range(10):
        print(pr(x, i))
    print(misclassied_labels[num])
    print(misclassied_prediction[num])
    show(image, num)

In [31]:
# Misclassified sample 0. Output order: 10 class log-densities, 10 normalized
# class probabilities, true label, predicted label (image saved as hw3_f0.png).
showimg(0)

-4151.31782831
-4117.49110453
-4096.9463359
-4090.66544866
-4086.22364417
-4128.0448384
-4222.34387252
-4019.11258217
-4050.18382947
-4027.93589959
3.83642909742e-58
1.88229050523e-43
1.5745857002e-34
8.41242044366e-32
7.14450512417e-30
4.91199898034e-48
5.46648426975e-89
0.999852762653
3.20527137638e-14
0.000147237346616
9
7


In [32]:
# Misclassified sample 1. Output order: 10 class log-densities, 10 normalized
# class probabilities, true label, predicted label (image saved as hw3_f1.png).
showimg(1)

-4172.9391351
-4220.6596437
-4154.94465887
-4111.02018722
-4057.39988407
-4102.86080252
-4230.12106309
-4016.94580562
-4099.00369201
-4016.11101327
5.41843235449e-69
1.02122185168e-89
3.53814303624e-61
4.21629554725e-42
8.16456067166e-19
1.47403161078e-38
7.94470667334e-94
0.302632708703
6.97635555104e-37
0.697367291297
7
9


In [33]:
# Misclassified sample 2. Output order: 10 class log-densities, 10 normalized
# class probabilities, true label, predicted label (image saved as hw3_f2.png).
showimg(2)

-4121.44934225
-3984.35618006
-4049.65048523
-4062.18716738
-4035.96411812
-4097.57736074
-4135.84358761
-3993.897051
-4027.143724
-4018.72752125
2.89177799286e-60
0.999928150923
4.39554126809e-29
1.57906671712e-34
3.86301214727e-23
6.7396103809e-50
1.6211515571e-66
7.18490769656e-05
2.61561951291e-19
1.18218373181e-15
7
1


In [34]:
# Misclassified sample 3. Output order: 10 class log-densities, 10 normalized
# class probabilities, true label, predicted label (image saved as hw3_f3.png).
showimg(3)

-4198.33987343
-4483.56585078
-4153.53066668
-4150.28403067
-4167.60467053
-4151.8224122
-4314.46293292
-4245.20743096
-4118.54862212
-4147.50939696
2.22382113759e-35
2.9856011722e-159
6.41935040135e-16
1.65001230999e-14
4.95711360137e-22
3.54304426429e-15
8.23182054504e-86
9.83508766823e-56
1.0
2.64542403175e-13
9
8


In [35]:
# Misclassified sample 4. Output order: 10 class log-densities, 10 normalized
# class probabilities, true label, predicted label (image saved as hw3_f4.png).
showimg(4)

-4072.45746032
-3980.60809779
-4053.17623046
-4042.80699934
-4019.46121942
-4059.22463316
-4077.73188357
-3990.82883341
-4032.42963329
-3992.19601969
1.28916601275e-40
0.999954317114
3.04817978588e-32
9.71271024566e-28
1.3374693248e-17
7.19867654299e-35
6.60170221784e-43
3.64058418132e-05
3.12013399296e-23
9.27704411259e-06
7
1


In [36]:
# Misclassified sample 5. Output order: 10 class log-densities, 10 normalized
# class probabilities, true label, predicted label (image saved as hw3_f5.png).
showimg(5)

-4149.58474557
-4123.35664557
-4093.75776622
-4083.98228218
-4079.78867824
-4083.9562275
-4128.7655091
-4112.15052495
-4080.80955962
-4085.99832937
3.49979418125e-31
8.60521263668e-20
6.15731792997e-07
0.0108350358582
0.717942076669
0.0111210491069
3.85231704741e-22
6.33166539494e-15
0.258658203181
0.00144301945248
3
4


In [37]:
# Misclassified sample 6. Output order: 10 class log-densities, 10 normalized
# class probabilities, true label, predicted label (image saved as hw3_f6.png).
showimg(6)

-4110.23569102
-4181.06059395
-4135.59725164
-4110.19299616
-4127.62827902
-4030.69483226
-4056.91344634
-4250.95194014
-4068.16163638
-4203.19650011
2.85656396469e-35
4.97713512754e-66
2.76349144526e-46
2.98116557154e-35
7.98617756946e-43
0.999999999996
4.10582640415e-12
2.20573558113e-96
5.35024557641e-17
1.21192999005e-75
6
5


In [38]:
# Misclassified sample 7. Output order: 10 class log-densities, 10 normalized
# class probabilities, true label, predicted label (image saved as hw3_f7.png).
showimg(7)

-4103.01569621
-4384.76397609
-4101.85262347
-4056.38113816
-4117.87423616
-4076.42231494
-4178.63848847
-4151.89444699
-4048.09040673
-4050.91815456
1.32187755925e-24
5.74735458553e-147
4.22967796062e-24
0.000236767699728
4.65810547878e-31
4.68327835047e-13
1.8994586485e-57
7.82383137953e-46
0.943933430433
0.0558298018664
9
8


In [39]:
# Misclassified sample 8. Output order: 10 class log-densities, 10 normalized
# class probabilities, true label, predicted label (image saved as hw3_f8.png).
showimg(8)

-4137.59982506
-4160.74440416
-4075.69293614
-4114.98805951
-4113.6287968
-4123.58841366
-4068.29319058
-4188.51241283
-4117.34937632
-4209.71202029
7.94777637329e-31
7.05801615105e-41
0.000611034715609
5.25303921876e-21
2.04517717938e-20
9.66772510733e-25
0.999388965284
6.15445651354e-53
4.95340434743e-22
3.82223311795e-62
4
6


In [40]:
# Misclassified sample 9. Output order: 10 class log-densities, 10 normalized
# class probabilities, true label, predicted label (image saved as hw3_f9.png).
showimg(9)

-4112.08976495
-4026.26025712
-4075.31873714
-4054.76027648
-4122.5945178
-4070.21965137
-4069.15530057
-4164.58676844
-4028.06150874
-4143.22660602
4.55363375138e-38
0.858301226383
4.24436349029e-22
3.59947025365e-13
1.24796354359e-42
6.95532427163e-20
2.01631814803e-19
7.23101333947e-61
0.141698773617
1.3670985989e-51
8
1
