Permalink
Find file
8037bb5 Nov 27, 2014
65 lines (55 sloc) 2.61 KB
"""
mnist_average_darkness
~~~~~~~~~~~~~~~~~~~~~~
A naive classifier for recognizing handwritten digits from the MNIST
data set. The program classifies digits based on how dark they are
--- the idea is that digits like "1" tend to be less dark than digits
like "8", simply because the latter has a more complex shape. When
shown an image the classifier returns whichever digit in the training
data had the closest average darkness.
The program works in two steps: first it trains the classifier, and
then it applies the classifier to the MNIST test data to see how many
digits are correctly classified.
Needless to say, this isn't a very good way of recognizing handwritten
digits! Still, it's useful to show what sort of performance we get
from naive ideas."""
#### Libraries
# Standard library
from collections import defaultdict
# My libraries
import mnist_loader
def main():
    """Train the average-darkness baseline and report its test accuracy.

    Loads the MNIST data via ``mnist_loader.load_data()``, computes the
    per-digit average darkness from the training split with
    ``avg_darknesses``, then counts how many test images ``guess_digit``
    labels correctly and prints the result.
    """
    # The validation split is not used by this baseline.
    training_data, _validation_data, test_data = mnist_loader.load_data()
    # training phase: compute the average darknesses for each digit,
    # based on the training data
    avgs = avg_darknesses(training_data)
    # testing phase: see how many of the test images are classified
    # correctly
    test_images, test_labels = test_data[0], test_data[1]
    num_correct = sum(int(guess_digit(image, avgs) == digit)
                      for image, digit in zip(test_images, test_labels))
    # A single-argument print(...) behaves the same on Python 2 and 3,
    # whereas the bare print statement is a SyntaxError on Python 3.
    print("Baseline classifier using average darkness of image.")
    print("%s of %s values correct." % (num_correct, len(test_labels)))
def avg_darknesses(training_data):
    """Return a defaultdict mapping each digit label to its average darkness.

    ``training_data`` is indexed as ``training_data[0]`` (the images) and
    ``training_data[1]`` (the matching digit labels), paired up with
    ``zip``. The darkness of an image is the sum of its pixel values. For
    every label that appears in the training data, the returned value is
    the mean darkness of the images carrying that label.

    The result is a ``defaultdict(float)``, so looking up a label that
    never appeared yields 0.0.
    """
    digit_counts = defaultdict(int)
    darknesses = defaultdict(float)
    for image, digit in zip(training_data[0], training_data[1]):
        digit_counts[digit] += 1
        darknesses[digit] += sum(image)
    # ``darknesses`` holds floats, so this is true division on Python 2
    # as well. ``.items()`` replaces the Python 2-only ``.iteritems()``.
    return defaultdict(float, {digit: darknesses[digit] / n
                               for digit, n in digit_counts.items()})
def guess_digit(image, avgs):
    """Return the digit whose average darkness is closest to ``image``'s.

    The darkness of ``image`` is ``sum(image)``. ``avgs`` maps digits to
    average darknesses, e.g. the result of ``avg_darknesses``. Only the
    keys already present in ``avgs`` are considered, so a defaultdict is
    never extended by this lookup.

    Raises ValueError (from ``min``) if ``avgs`` is empty.
    """
    darkness = sum(image)
    # Iterate the mapping directly rather than via the Python 2-only
    # ``.iteritems()``; no intermediate distances dict is needed.
    return min(avgs, key=lambda digit: abs(avgs[digit] - darkness))
if __name__ == "__main__":
    # Run the train/evaluate pipeline only when executed as a script,
    # not when this module is imported.
    main()