In [1]:
import util, perceptron, nb, samples

In [2]:
# Width and height of one image per dataset; the feature extractors use
# these as the bounds of the (x, y) pixel coordinates they visit.
DIGIT_DATUM_WIDTH, DIGIT_DATUM_HEIGHT = 28, 28
FACE_DATUM_WIDTH, FACE_DATUM_HEIGHT = 60, 70

In [3]:
def analysis(classifier, guesses, testLabels, testData, rawTestData, printImage):
  """
  Post-training hook for inspecting how the classifier did on the test set.

  As written, it reports only the first misclassified test example: its
  index, the predicted and true labels, and the raw datum. Nothing is
  printed when every guess matches its label.

  Arguments:
  - classifier: the trained classifier (unused here)
  - guesses: labels predicted by the classifier on the test set
  - testLabels: the true labels, index-aligned with guesses
  - testData: the test datapoints as util.Counter feature maps (unused here)
  - rawTestData: the test datapoints in raw form (samples.Datum),
    index-aligned with guesses
  - printImage: callable for visualizing a list of pixels (unused here)

  This is optional scratch space for your own analysis; the signature may
  be changed freely.
  """
  for index, predicted in enumerate(guesses):
    actual = testLabels[index]
    if predicted != actual:
      # First disagreement found: describe it and stop looking.
      print("===================================")
      print("Mistake on example %d" % index)
      print("Predicted %d; truth is %d" % (predicted, actual))
      print("Image: ")
      print(rawTestData[index])
      return

In [4]:
def _binaryPixelFeatures(datum, width, height):
  """
  Shared implementation for the basic feature extractors.

  Returns a util.Counter with one entry per (x, y) coordinate for
  x in range(width) and y in range(height). The entry is 1 when
  datum.getPixel(x, y) > 0 and 0 otherwise.
  """
  features = util.Counter()
  for x in range(width):
    for y in range(height):
      features[(x, y)] = 1 if datum.getPixel(x, y) > 0 else 0
  return features

def basicFeatureExtractorDigit(datum):
  """
  Returns a set of pixel features indicating whether
  each pixel in the provided datum is white (0) or gray/black (1).

  Covers the DIGIT_DATUM_WIDTH x DIGIT_DATUM_HEIGHT grid; features are
  keyed by (x, y) in a util.Counter.
  """
  return _binaryPixelFeatures(datum, DIGIT_DATUM_WIDTH, DIGIT_DATUM_HEIGHT)

def basicFeatureExtractorFace(datum):
  """
  Returns a set of pixel features indicating whether
  each pixel in the provided datum is an edge (1) or no edge (0).

  Covers the FACE_DATUM_WIDTH x FACE_DATUM_HEIGHT grid; features are
  keyed by (x, y) in a util.Counter.
  """
  return _binaryPixelFeatures(datum, FACE_DATUM_WIDTH, FACE_DATUM_HEIGHT)

In [5]:
# Experiment settings -- edit these, then re-run the cells below.
DATASET = 'faces'          # 'digits' or 'faces'
CLASSIFIER = 'Perceptron'  # 'Perceptron' or 'NaiveBayes'
ITERATIONS = 10            # handed to the classifier as max_iterations
TRAIN_PERCENT = 10         # percentage of the training-set size to load
TEST_PERCENT = 100         # percentage of the test-set size to load (also sizes validation)


In [6]:
# Lookup tables from the names used in the settings cell to the objects
# they select: CLASSIFIER -> classifier class, DATASET -> feature extractor.
classifiers = {
    'Perceptron': perceptron.Perceptron,
    'NaiveBayes': nb.NaiveBayesClassifier,
}

features = {
    'digits': basicFeatureExtractorDigit,
    'faces': basicFeatureExtractorFace,
}

In [7]:
# Per-dataset parameters: image size, label set, and full split sizes.
# Anything other than 'digits' takes the faces values here; an unrecognised
# DATASET then fails with KeyError at the features lookup at the end.
if DATASET == 'digits':
    DATUM_WIDTH, DATUM_HEIGHT = DIGIT_DATUM_WIDTH, DIGIT_DATUM_HEIGHT
    legalLabels = list(range(10))
    numTraining, numTest = 5000, 1000
else:
    DATUM_WIDTH, DATUM_HEIGHT = FACE_DATUM_WIDTH, FACE_DATUM_HEIGHT
    legalLabels = list(range(2))
    numTraining, numTest = 451, 150

# Scale the full split sizes down to the requested percentages (truncating).
numTraining = int(numTraining * TRAIN_PERCENT / 100)
numTest = int(numTest * TEST_PERCENT / 100)

getFeatures = features[DATASET]

In [8]:
# Sanity check: display the image width selected for the configured DATASET.
DATUM_WIDTH

60

In [9]:
# Pick the image/label files for each split, then load them with one shared
# sequence of calls. Any DATASET other than "faces" uses the digit files.
if DATASET == "faces":
    _trainImages, _trainLabels = "data/facedata/facedatatrain", "data/facedata/facedatatrainlabels"
    # NOTE(review): the faces validation split is read from the same files as
    # the training split (first numTest examples), so it overlaps the training
    # data. Confirm whether a dedicated validation file should be used instead.
    _validImages, _validLabels = "data/facedata/facedatatrain", "data/facedata/facedatatrainlabels"
    _testImages, _testLabels = "data/facedata/facedatatest", "data/facedata/facedatatestlabels"
else:
    _trainImages, _trainLabels = "data/digitdata/trainingimages", "data/digitdata/traininglabels"
    _validImages, _validLabels = "data/digitdata/validationimages", "data/digitdata/validationlabels"
    _testImages, _testLabels = "data/digitdata/testimages", "data/digitdata/testlabels"

rawTrainingData = samples.loadDataFile(_trainImages, numTraining, DATUM_WIDTH, DATUM_HEIGHT)
trainingLabels = samples.loadLabelsFile(_trainLabels, numTraining)
# Validation and test both load numTest examples.
rawValidationData = samples.loadDataFile(_validImages, numTest, DATUM_WIDTH, DATUM_HEIGHT)
validationLabels = samples.loadLabelsFile(_validLabels, numTest)
rawTestData = samples.loadDataFile(_testImages, numTest, DATUM_WIDTH, DATUM_HEIGHT)
testLabels = samples.loadLabelsFile(_testLabels, numTest)

In [10]:
# Bound printImage method of an ImagePrinter built for the active image size;
# it is later passed to analysis().
printImage = util.ImagePrinter(
    DATUM_WIDTH,
    DATUM_HEIGHT,
).printImage

In [11]:
# Sanity check: display the bound method to confirm the printer was created.
printImage

<bound method ImagePrinter.printImage of <util.ImagePrinter object at 0x7fd590eda050>>

In [12]:
# Convert every raw datum of each split into its feature Counter.
trainingData = [getFeatures(datum) for datum in rawTrainingData]
validationData = [getFeatures(datum) for datum in rawValidationData]
testData = [getFeatures(datum) for datum in rawTestData]

In [13]:
# Instantiate the classifier class selected by CLASSIFIER.
classifier = classifiers[CLASSIFIER](
    legalLabels,
    max_iterations=ITERATIONS,
)

In [14]:
# Fit on the training split; the validation split is handed to train() as well.
classifier.train(trainingData, trainingLabels, validationData, validationLabels)

# Classify the test split and count how many guesses match the true labels.
guesses = classifier.classify(testData)
numTestLabels = len(testLabels)
correct = sum(1 for i in range(numTestLabels) if guesses[i] == testLabels[i])

# Guard the percentage so an empty test set (e.g. TEST_PERCENT = 0) reports
# 0.0% instead of raising ZeroDivisionError.
accuracy = 100.0 * correct / numTestLabels if numTestLabels else 0.0
print("%d correct out of %d (%.1f%%)." % (correct, numTestLabels, accuracy))

Starting iteration  0 ...
Starting iteration  1 ...
Starting iteration  2 ...
Starting iteration  3 ...
Starting iteration  4 ...
Starting iteration  5 ...
Starting iteration  6 ...
Starting iteration  7 ...
Starting iteration  8 ...
Starting iteration  9 ...
91 correct out of 150 (60.7%).


In [15]:
# Show the first misclassified test example, if any.
analysis(
    classifier=classifier,
    guesses=guesses,
    testLabels=testLabels,
    testData=testData,
    rawTestData=rawTestData,
    printImage=printImage,
)

Mistake on example 0
Predicted 0; truth is 1
Image: 
                                                            
 ####                                                       
     ###                      #                             
        ####                ## #                            
            ######         #    #                           
                  #########      ######                     
 #                                     ##                   
  #                #         ##          #                  
  #               #         #  #          #                 
  #               #         #   #          #                
 #                 #  ######    #          #                
 #                  ##          #           #               
 #                 #            #            #              
 #             ####                           #             
                                              #             
             #                  