In [1]:
import numpy as np
import os
import pickle

In [2]:
def load_CIFAR_batcha(filename):
    """ load single batch of cifar """
    with open(filename, 'rb') as f:
        datadict = pickle.load(f, encoding = 'latin1')
        X = datadict['data']
        Y = datadict['labels']
        X = X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype('float')
        Y = np.array(Y)
        return X, Y

In [3]:
def load_CIFAR10(ROOT):
    """ load all of cifar """
    xs = []
    ys = []
    for b in range(1,6):
        f = os.path.join(ROOT, 'data_batch_%d'%(b,))
        X, Y = load_CIFAR_batcha(f)
        xs.append(X)
        ys.append(Y)
    Xtr = np.concatenate(xs)
    Ytr = np.concatenate(ys)
    del X, Y
    Xte, Yte = load_CIFAR_batcha(os.path.join(ROOT, 'test_batch'))
    return Xtr, Ytr, Xte, Yte


In [4]:
Xtr, Ytr, Xte, Yte = load_CIFAR10('cs231n/datasets/cifar-10-batches-py/')

In [5]:
Yte.shape

(10000,)

In [6]:
Xte.shape

(10000, 32, 32, 3)

In [7]:
Ytr.shape

(50000,)

In [8]:
Ytr

array([6, 9, 9, ..., 9, 1, 1])

In [9]:
Xte[0] #one of the 32*32 rgb pic

array([[[ 158.,  112.,   49.],
        [ 159.,  111.,   47.],
        [ 165.,  116.,   51.],
        ..., 
        [ 137.,   95.,   36.],
        [ 126.,   91.,   36.],
        [ 116.,   85.,   33.]],

       [[ 152.,  112.,   51.],
        [ 151.,  110.,   40.],
        [ 159.,  114.,   45.],
        ..., 
        [ 136.,   95.,   31.],
        [ 125.,   91.,   32.],
        [ 119.,   88.,   34.]],

       [[ 151.,  110.,   47.],
        [ 151.,  109.,   33.],
        [ 158.,  111.,   36.],
        ..., 
        [ 139.,   98.,   34.],
        [ 130.,   95.,   34.],
        [ 120.,   89.,   33.]],

       ..., 
       [[  68.,  124.,  177.],
        [  42.,  100.,  148.],
        [  31.,   88.,  137.],
        ..., 
        [  38.,   97.,  146.],
        [  13.,   64.,  108.],
        [  40.,   85.,  127.]],

       [[  61.,  116.,  168.],
        [  49.,  102.,  148.],
        [  35.,   85.,  132.],
        ..., 
        [  26.,   82.,  130.],
        [  29.,   82.,  126.],
        [ 

In [10]:
Xtr_rows = Xtr.reshape(Xtr.shape[0], 32*32*3)
Xtr_rows.shape

(50000, 3072)

In [11]:
Xte_rows = Xte.reshape(Xte.shape[0], 32*32*3)
Xte_rows.shape

(10000, 3072)

In [12]:
Xtr_rows - Xte_rows[0,:]

array([[ -99.,  -50.,   14., ...,  102.,   25.,  -38.],
       [  -4.,   65.,  138., ...,  122.,   66.,   34.],
       [  97.,  143.,  206., ...,   59.,   19.,  -26.],
       ..., 
       [-123.,   66.,  186., ...,   -9.,  -36.,  -60.],
       [  31.,   99.,  191., ...,  174.,  123.,   61.],
       [  71.,  117.,  190., ...,  142.,   96.,   51.]])

In [13]:
Xte_rows[0,:].shape

(3072,)

In [14]:
a = np.full((2,3),7)
a


array([[7, 7, 7],
       [7, 7, 7]])

In [15]:
b = [1,2,3]

In [16]:
a-b

array([[6, 5, 4],
       [6, 5, 4]])

In [17]:
np.sum(a-b,axis=1)

array([15, 15])

In [18]:
class NearestNeighbor(object):
    def __init__(self):
        pass
    def train(self, X, y):
        """ X is N x D where each row is an example. Y is l-dimension of size N"""
        # the nearest neighbor classifire simply remembers all the training data
        self.Xtr = X
        self.ytr = y
        
    def predict(self, X):
        """ X is N x D where each row is an example we wish to predict label for"""
        num_test = X.shape[0]
        Ypred = np.zeros(num_test, dtype=self.ytr.dtype)
        
        #loop over all test rows
        for i in range(num_test):
            distances = np.sum(np.abs(self.Xtr - X[i,:]), axis = 1)
            min_index = np.argmin(distances) #get the index with smallest distance
            Ypred[i] = self.ytr[min_index] #predict the label of the nearest example
            print('No.{} pic maybe {}: {} '.format(i,Ypred[i],Ypred[i]==Yte[i]))
        return Ypred
    

In [19]:
nn = NearestNeighbor()
nn.train(Xtr_rows, Ytr)

In [20]:
Yte_predict = nn.predict(Xte_rows)

No.0 pic maybe 4: False 
No.1 pic maybe 8: True 
No.2 pic maybe 1: False 
No.3 pic maybe 0: True 
No.4 pic maybe 4: False 
No.5 pic maybe 6: True 
No.6 pic maybe 4: False 
No.7 pic maybe 2: False 
No.8 pic maybe 4: False 
No.9 pic maybe 8: False 
No.10 pic maybe 8: False 
No.11 pic maybe 8: False 


KeyboardInterrupt: 

In [None]:
print ('accuracy: %f' % (np.mean(Yte_predict == Yte)))

accuracy: 38.6%

In [None]:
#L2 distance
distances = np.sqrt(np.sum(np.square(self.Xtr - X[i,:]), axis=1))