Permalink
Cannot retrieve contributors at this time
"""
mnist_svm
~~~~~~~~~

A classifier program for recognizing handwritten digits from the MNIST
data set, using an SVM classifier."""
#### Libraries | |
# My libraries | |
import mnist_loader | |
# Third-party libraries | |
from sklearn import svm | |
def svm_baseline():
    """Train an SVM on MNIST and print how many test digits it gets right.

    Loads the data with ``mnist_loader.load_data()``, fits a
    ``sklearn.svm.SVC`` using its default hyperparameters on the training
    split, predicts the test split, and prints the number of correct
    predictions out of the number of test examples. The validation split
    is loaded but not used.

    Returns:
        None. Results are reported on stdout only.
    """
    training_data, _validation_data, test_data = mnist_loader.load_data()
    # Each split is indexed as [0] = inputs, [1] = labels (presumably a pair
    # of arrays produced by mnist_loader -- TODO confirm against that module).
    train_inputs, train_labels = training_data[0], training_data[1]
    test_inputs, test_labels = test_data[0], test_data[1]

    # Train.
    clf = svm.SVC()
    clf.fit(train_inputs, train_labels)

    # Test. Each prediction is cast to a plain int before being compared
    # with the corresponding test label.
    predictions = [int(a) for a in clf.predict(test_inputs)]
    num_correct = sum(int(a == y) for a, y in zip(predictions, test_labels))

    # print(...) with a single argument behaves the same on Python 2 and 3,
    # whereas the old print statement is a SyntaxError on Python 3.
    print("Baseline classifier using an SVM.")
    print("%s of %s values correct." % (num_correct, len(test_labels)))
if __name__ == "__main__":
    # Only run the (slow) training/evaluation when executed as a script,
    # not when this module is imported.
    svm_baseline()