# MNIST 数据集

In [8]:
import pandas as pd
import numpy as np
import csv
import scipy.special
import matplotlib.pyplot as plt

In [9]:
class neuralNetwork:
    """Three-layer (input -> hidden -> output) fully connected network.

    Both layers use the logistic sigmoid (``scipy.special.expit``).
    ``train`` performs one per-sample back-propagation step, and ``query``
    runs the forward pass only.
    """

    def __init__(self, inputnodes, hidennodes, outputnodes, learningrate, rng=None):
        """Create the network with randomly initialised weights.

        Args:
            inputnodes: number of input units.
            hidennodes: number of hidden units.
            outputnodes: number of output units.
            learningrate: step size used by ``train``.
            rng: optional NumPy random source exposing ``.normal`` (e.g. a
                ``numpy.random.Generator``). It is used to draw the initial
                weights reproducibly. When ``None`` (the default), NumPy's
                global random state is used, exactly as before.
        """
        self.inodes = inputnodes
        self.hnodes = hidennodes
        self.onodes = outputnodes
        self.lr = learningrate

        normal = np.random.normal if rng is None else rng.normal
        # Zero-mean normal weights with std = 1/sqrt(size of the receiving layer).
        # NOTE(review): the more common LeCun/Xavier choice scales by the fan-in
        # (self.inodes for wih, self.hnodes for who). It is kept as-is so that
        # trained results stay comparable.
        self.wih = normal(0.0, pow(self.hnodes, -0.5), (self.hnodes, self.inodes))
        self.who = normal(0.0, pow(self.onodes, -0.5), (self.onodes, self.hnodes))

        self.activation_function = scipy.special.expit

    def _forward(self, inputs):
        """Propagate column-vector ``inputs``; return (hidden_outputs, final_outputs)."""
        hidden_outputs = self.activation_function(np.dot(self.wih, inputs))
        final_outputs = self.activation_function(np.dot(self.who, hidden_outputs))
        return hidden_outputs, final_outputs

    def train(self, inputs_list, target_list):
        """Apply one gradient-descent update for a single (input, target) pair.

        Args:
            inputs_list: 1-D sequence of ``self.inodes`` input values.
            target_list: 1-D sequence of ``self.onodes`` target values.
        """
        # Turn the 1-D sequences into column vectors of shape (n, 1).
        inputs = np.array(inputs_list, ndmin=2).T
        targets = np.array(target_list, ndmin=2).T

        hidden_outputs, final_outputs = self._forward(inputs)

        output_errors = targets - final_outputs
        # Back-propagate through the hidden->output weights *before* they are updated.
        hidden_errors = np.dot(self.who.T, output_errors)

        # Delta rule: error * sigmoid'(x), where sigmoid'(x) = y * (1 - y).
        self.who += self.lr * np.dot(
            output_errors * final_outputs * (1.0 - final_outputs), hidden_outputs.T
        )
        self.wih += self.lr * np.dot(
            hidden_errors * hidden_outputs * (1.0 - hidden_outputs), inputs.T
        )

    def query(self, inputs_list):
        """Return the output activations, shape (self.onodes, 1), for one input sample."""
        inputs = np.array(inputs_list, ndmin=2).T
        _, final_outputs = self._forward(inputs)
        return final_outputs

In [10]:
# Network hyperparameters.
# 784 inputs presumably correspond to the 28x28 pixels of an MNIST image (confirm against the CSV);
# the 10 outputs are indexed by the integer label in the training loop.
input_nodes = 784
hidden_nodes = 200
output_nodes = 10
learning_rate = 0.2

n = neuralNetwork(
    inputnodes=input_nodes,
    hidennodes=hidden_nodes,
    outputnodes=output_nodes,
    learningrate=learning_rate,
)

In [11]:
# Load every line of the training CSV into memory as a list of strings.
with open("mnist_train.csv", "r") as train_csv:
    data_list = list(train_csv)

In [12]:
# Train for several passes over the whole training set, one record at a time.
epochs = 5
for epoch in range(epochs):
    for record in data_list:
        # Skip blank lines (e.g. a trailing empty line in the CSV); int('') would crash below.
        if not record.strip():
            continue
        all_values = record.split(',')
        # Rescale pixel values (assumed 0-255) into [0.01, 1.00] so no input is exactly 0.
        # np.asfarray was removed in NumPy 2.0; np.asarray(..., dtype=float) is the replacement.
        inputs = np.asarray(all_values[1:], dtype=float) / 255.0 * 0.99 + 0.01
        # Target vector: 0.99 at the labelled class, 0.01 everywhere else
        # (avoids the sigmoid's unreachable 0/1 extremes).
        targets = np.full(output_nodes, 0.01)
        targets[int(all_values[0])] = 0.99
        n.train(inputs, targets)

In [23]:
# Load every line of the test CSV into memory as a list of strings.
with open("mnist_test.csv", "r") as test_csv:
    test_data_list = list(test_csv)

In [28]:
# Evaluate the trained network on every test record.
scorecard = []       # 1 for a correct prediction, 0 otherwise
correct_values = []  # true labels, in file order
predictions = []     # predicted labels, in file order

for record in test_data_list:
    # Skip blank lines (e.g. a trailing empty line in the CSV); int('') would crash below.
    if not record.strip():
        continue
    all_values = record.split(',')
    correct_value = int(all_values[0])
    correct_values.append(correct_value)

    # Same input scaling as in training. np.asfarray was removed in NumPy 2.0.
    inputs = np.asarray(all_values[1:], dtype=float) / 255.0 * 0.99 + 0.01
    outputs = n.query(inputs)
    # The predicted label is the index of the most strongly activated output node.
    prediction = np.argmax(outputs)
    predictions.append(prediction)

    scorecard.append(1 if prediction == correct_value else 0)

In [29]:
# Quick sanity check: first ten true labels next to the first ten predictions.
for caption, values in (("正确数据：", correct_values), ("预测数据：", predictions)):
    print(caption, values[:10])

正确数据： [7, 2, 1, 0, 4, 1, 4, 9, 5, 9]
预测数据： [7, 2, 1, 0, 4, 1, 4, 9, 5, 9]


In [30]:
# Indices of test records whose prediction disagrees with the true label.
mismatch_mask = np.array(correct_values) != np.array(predictions)
errors = np.nonzero(mismatch_mask)
print(errors)

(array([  38,  115,  149,  233,  247,  259,  266,  290,  321,  445,  447,
        448,  495,  543,  582,  583,  591,  605,  610,  629,  659,  684,
        691,  717,  781,  844,  877,  882,  939,  947,  950,  965, 1003,
       1014, 1032, 1039, 1044, 1112, 1156, 1173, 1178, 1181, 1182, 1194,
       1226, 1232, 1247, 1251, 1260, 1289, 1299, 1319, 1325, 1328, 1337,
       1393, 1414, 1415, 1425, 1444, 1466, 1494, 1500, 1522, 1530, 1543,
       1546, 1549, 1551, 1553, 1554, 1581, 1609, 1644, 1678, 1681, 1686,
       1702, 1709, 1717, 1718, 1721, 1754, 1761, 1790, 1871, 1878, 1901,
       1940, 1941, 1952, 1973, 2024, 2044, 2053, 2070, 2093, 2098, 2125,
       2130, 2135, 2182, 2186, 2189, 2215, 2272, 2293, 2326, 2369, 2371,
       2380, 2387, 2406, 2414, 2425, 2433, 2447, 2454, 2488, 2496, 2525,
       2526, 2542, 2573, 2597, 2607, 2648, 2654, 2721, 2730, 2810, 2896,
       2915, 2927, 2939, 2970, 3030, 3060, 3062, 3073, 3117, 3157, 3166,
       3218, 3282, 3289, 3326, 3330, 3333, 3369, 3

In [31]:
# Accuracy = fraction of test records that were not misclassified.
misclassified_count = errors[0].size
p_1 = 1 - misclassified_count / len(correct_values)
print("accuracy：", p_1)

accuracy： 0.9657
