In [1]:
import numpy as np
import os
import pickle
from collections import Counter

In [2]:
def load_CIFAR_batcha(filename):
    """ load single batch of cifar """
    with open(filename, 'rb') as f:
        datadict = pickle.load(f, encoding = 'latin1')
        X = datadict['data']
        Y = datadict['labels']
        X = X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype('float')
        Y = np.array(Y)
        return X, Y

## X.reshape
X.shape  (10000, 3072(32*32*3))
transpose 转置轴

In [3]:
def load_CIFAR10(ROOT):
    """ load all of cifar """
    xs = []
    ys = []
    for b in range(1,6):
        f = os.path.join(ROOT, 'data_batch_%d'%(b,))
        X, Y = load_CIFAR_batcha(f)
        xs.append(X)
        ys.append(Y)
    Xtr = np.concatenate(xs)
    Ytr = np.concatenate(ys)
    del X, Y
    Xte, Yte = load_CIFAR_batcha(os.path.join(ROOT, 'test_batch'))
    return Xtr, Ytr, Xte, Yte


In [4]:
Xtr, Ytr, Xte, Yte = load_CIFAR10('cs231n/datasets/cifar-10-batches-py/')

In [None]:
Yte.shape

In [None]:
Xte.shape

In [None]:
Ytr.shape

In [None]:
Ytr

In [None]:
Xte[0] #one of the 32*32 rgb pic

In [5]:
Xtr_rows = Xtr.reshape(Xtr.shape[0], 32*32*3)
Xtr_rows.shape

(50000, 3072)

In [6]:
Xte_rows = Xte.reshape(Xte.shape[0], 32*32*3)
Xte_rows.shape

(10000, 3072)

In [7]:
Xval_rows = Xtr_rows[:1000,:] #take first 1000 for validation

In [8]:
Xval_rows.shape

(1000, 3072)

In [9]:
Yval = Ytr[:1000]

In [10]:
Yval.shape

(1000,)

In [11]:
Xtr_rows = Xtr_rows[1000:,:]
Ytr = Ytr[1000:]

In [12]:
Xtr_rows.shape

(49000, 3072)

In [13]:
validation_accuracies = []

In [62]:
class NearestNeighbor(object):
    def __init__(self):
        pass
    def train(self, X, y):
        """ X is N x D where each row is an example. Y is l-dimension of size N"""
        # the nearest neighbor classifire simply remembers all the training data
        self.Xtr = X
        self.ytr = y
        
    def predict(self, X, k):
        """ X is N x D where each row is an example we wish to predict label for"""
        num_test = X.shape[0]
        prediction = np.zeros(num_test, dtype=self.ytr.dtype)
        
        #loop over all test rows
        for i in range(num_test):
            distances = np.sum(np.abs(self.Xtr - X[i,:]), axis = 1)
            ytrr = self.ytr
            #print(distances.shape)
            #find k min value 
            k_min_index = []
            Ypred = []
            for j in range(k):
                min_index = np.argmin(distances)
                #print(min_index)
                k_min_index.append(min_index) #get the index with smallest distance
                Ypred.append(ytrr[min_index])
                #print(ytrr[min_index])#predict the label of the nearest example
                ytrr = np.delete(ytrr,min_index)
                distances = np.delete(distances,min_index)
                #print(distances.shape)
                #print(ytrr.shape)
            prediction[i] = Counter(Ypred).most_common(1)[0][0]
            
            print('Pic No.{} maybe {}: {} accuracy:{}'.format(i,Counter(Ypred).most_common(1),prediction[i]==Yte[i],np.mean(prediction[:i+1] == Yte[:i+1])))
        return Ypred
    

In [None]:
for k in [1,3,5,10,20,50,100]:
    nn = NearestNeighbor()
    nn.train(Xtr_rows, Ytr)
    Yval_predict = nn.predict(Xval_rows, k=k)

In [None]:
nn = NearestNeighbor()
nn.train(Xtr_rows, Ytr)
Yte_predict = nn.predict(Xte_rows,7)

Pic No.0 maybe [(4, 2)]: False accuracy:0.0
Pic No.1 maybe [(8, 4)]: True accuracy:0.5
Pic No.2 maybe [(8, 5)]: True accuracy:0.6666666666666666
Pic No.3 maybe [(0, 6)]: True accuracy:0.75
Pic No.4 maybe [(4, 3)]: False accuracy:0.6
Pic No.5 maybe [(6, 5)]: True accuracy:0.6666666666666666
Pic No.6 maybe [(6, 3)]: False accuracy:0.5714285714285714
Pic No.7 maybe [(4, 4)]: False accuracy:0.5
Pic No.8 maybe [(4, 2)]: False accuracy:0.4444444444444444
Pic No.9 maybe [(8, 5)]: False accuracy:0.4
Pic No.10 maybe [(0, 5)]: True accuracy:0.45454545454545453
Pic No.11 maybe [(1, 2)]: False accuracy:0.4166666666666667
Pic No.12 maybe [(5, 2)]: True accuracy:0.46153846153846156
Pic No.13 maybe [(6, 2)]: False accuracy:0.42857142857142855
Pic No.14 maybe [(4, 3)]: False accuracy:0.4
Pic No.15 maybe [(8, 3)]: True accuracy:0.4375
Pic No.16 maybe [(3, 2)]: False accuracy:0.4117647058823529
Pic No.17 maybe [(4, 3)]: False accuracy:0.3888888888888889
Pic No.18 maybe [(1, 3)]: False accuracy:0.3684210

Pic No.142 maybe [(6, 6)]: True accuracy:0.3986013986013986
Pic No.143 maybe [(3, 3)]: True accuracy:0.4027777777777778
Pic No.144 maybe [(8, 3)]: True accuracy:0.4068965517241379
Pic No.145 maybe [(9, 2)]: False accuracy:0.4041095890410959
Pic No.146 maybe [(2, 2)]: False accuracy:0.4013605442176871
Pic No.147 maybe [(7, 3)]: False accuracy:0.39864864864864863
Pic No.148 maybe [(7, 3)]: False accuracy:0.3959731543624161
Pic No.149 maybe [(3, 4)]: False accuracy:0.3933333333333333
Pic No.150 maybe [(8, 7)]: True accuracy:0.3973509933774834
Pic No.151 maybe [(4, 3)]: False accuracy:0.39473684210526316
Pic No.152 maybe [(4, 2)]: False accuracy:0.39215686274509803
Pic No.153 maybe [(8, 4)]: False accuracy:0.38961038961038963
Pic No.154 maybe [(0, 6)]: True accuracy:0.3935483870967742
Pic No.155 maybe [(3, 3)]: False accuracy:0.391025641025641
Pic No.156 maybe [(2, 3)]: True accuracy:0.39490445859872614
Pic No.157 maybe [(8, 4)]: False accuracy:0.3924050632911392
Pic No.158 maybe [(5, 2)]:

Pic No.278 maybe [(5, 4)]: False accuracy:0.3655913978494624
Pic No.279 maybe [(5, 3)]: False accuracy:0.36428571428571427
Pic No.280 maybe [(9, 4)]: True accuracy:0.3665480427046263
Pic No.281 maybe [(4, 4)]: False accuracy:0.36524822695035464
Pic No.282 maybe [(2, 4)]: False accuracy:0.36395759717314485
Pic No.283 maybe [(1, 7)]: True accuracy:0.36619718309859156
Pic No.284 maybe [(9, 3)]: False accuracy:0.3649122807017544
Pic No.285 maybe [(9, 2)]: True accuracy:0.36713286713286714
Pic No.286 maybe [(1, 2)]: True accuracy:0.3693379790940767
Pic No.287 maybe [(8, 7)]: False accuracy:0.3680555555555556
Pic No.288 maybe [(3, 4)]: False accuracy:0.36678200692041524
Pic No.289 maybe [(9, 5)]: True accuracy:0.3689655172413793
Pic No.290 maybe [(1, 4)]: True accuracy:0.3711340206185567
Pic No.291 maybe [(2, 3)]: True accuracy:0.3732876712328767
Pic No.292 maybe [(2, 4)]: False accuracy:0.3720136518771331
Pic No.293 maybe [(0, 2)]: False accuracy:0.3707482993197279
Pic No.294 maybe [(6, 3)]

Pic No.414 maybe [(9, 3)]: False accuracy:0.3686746987951807
Pic No.415 maybe [(9, 4)]: True accuracy:0.3701923076923077
Pic No.416 maybe [(6, 3)]: False accuracy:0.36930455635491605
Pic No.417 maybe [(7, 2)]: True accuracy:0.3708133971291866
Pic No.418 maybe [(3, 5)]: True accuracy:0.3723150357995227
Pic No.419 maybe [(7, 3)]: True accuracy:0.3738095238095238
Pic No.420 maybe [(6, 4)]: False accuracy:0.37292161520190026
Pic No.421 maybe [(4, 5)]: False accuracy:0.37203791469194314
Pic No.422 maybe [(8, 6)]: False accuracy:0.37115839243498816
Pic No.423 maybe [(6, 3)]: False accuracy:0.37028301886792453
Pic No.424 maybe [(2, 7)]: True accuracy:0.37176470588235294
Pic No.425 maybe [(8, 5)]: False accuracy:0.37089201877934275
Pic No.426 maybe [(5, 3)]: False accuracy:0.3700234192037471
Pic No.427 maybe [(6, 5)]: False accuracy:0.3691588785046729
Pic No.428 maybe [(7, 2)]: False accuracy:0.3682983682983683
Pic No.429 maybe [(6, 7)]: True accuracy:0.3697674418604651
Pic No.430 maybe [(6, 3

Pic No.550 maybe [(2, 3)]: False accuracy:0.3738656987295826
Pic No.551 maybe [(6, 3)]: False accuracy:0.37318840579710144
Pic No.552 maybe [(8, 5)]: True accuracy:0.3743218806509946
Pic No.553 maybe [(6, 2)]: False accuracy:0.37364620938628157
Pic No.554 maybe [(8, 2)]: False accuracy:0.372972972972973
Pic No.555 maybe [(8, 6)]: False accuracy:0.3723021582733813
Pic No.556 maybe [(6, 3)]: True accuracy:0.3734290843806104
Pic No.557 maybe [(4, 4)]: True accuracy:0.37455197132616486
Pic No.558 maybe [(3, 3)]: True accuracy:0.3756708407871199
Pic No.559 maybe [(2, 3)]: True accuracy:0.3767857142857143
Pic No.560 maybe [(0, 5)]: True accuracy:0.3778966131907308
Pic No.561 maybe [(7, 4)]: True accuracy:0.3790035587188612
Pic No.562 maybe [(4, 3)]: False accuracy:0.3783303730017762
Pic No.563 maybe [(4, 3)]: False accuracy:0.3776595744680851
Pic No.564 maybe [(4, 6)]: False accuracy:0.3769911504424779
Pic No.565 maybe [(3, 4)]: True accuracy:0.37809187279151946
Pic No.566 maybe [(9, 4)]: Tr

In [None]:
print ('accuracy: %f' % (np.mean(Yte_predict == Yte)))

accuracy: 38.6%

In [None]:
#L2 distance
distances = np.sqrt(np.sum(np.square(self.Xtr - X[i,:]), axis=1))