# 数据集：MNIST
- 训练集数量：60000
- 测试集数量：10000
- 原始为 10 类（数字 0 - 9）；本 notebook 将其转换为二分类：数字 0 为正类（标签 1），其余数字为负类（标签 0），并用二项逻辑斯蒂回归求解。

In [1]:
import numpy as np
from tqdm import tqdm


In [2]:
def load_data(fileName):
    """
    Load an MNIST-style CSV file and binarise its labels.

    Each non-empty line is parsed as comma-separated integers: the first field
    is the digit label and the remaining fields are pixel values, which are
    divided by 255 (presumably 0-255 grayscale). Label 0 is mapped to 1 and
    every other digit to 0, turning the 10-class problem into "digit 0 vs. rest".

    Parameters
    ----------
    fileName : path of the CSV file to read.

    Returns
    -------
    (dataList, labelList) : a list of feature rows (lists of floats) and the
        matching list of 0/1 labels.
    """
    dataList = []
    labelList = []

    # `with` guarantees the file handle is closed once the lines are read.
    with open(fileName, 'r') as fr:
        lines = fr.readlines()

    for line in tqdm(lines):
        current_Line = line.strip().split(',')

        # Skip blank lines (e.g. an extra empty line at the end of the file);
        # int('') would otherwise raise ValueError.
        if current_Line == ['']:
            continue

        # Binary task: digit 0 -> 1, any other digit -> 0.
        labelList.append(1 if int(current_Line[0]) == 0 else 0)
        dataList.append([int(num) / 255 for num in current_Line[1:]])

    return dataList, labelList

def predict(w, x):
    """
    Predict the binary label of a single sample with a logistic model.

    Parameters
    ----------
    w : weight vector.
    x : feature vector with the same length as ``w`` (including any bias
        feature the caller appended).

    Returns
    -------
    int : 1 if P(y=1 | x) = sigmoid(w . x) is at least 0.5, otherwise 0.
    """
    wx = np.dot(w, x)
    # sigmoid(wx) >= 0.5 holds exactly when wx >= 0, so compare the score
    # directly. Evaluating exp(wx) / (1 + exp(wx)) overflows to inf / inf = nan
    # for large wx, and nan >= 0.5 is False, which mislabelled confident
    # positives as 0.
    return 1 if wx >= 0 else 0

def _sigmoid(z):
    """Numerically stable logistic function 1 / (1 + exp(-z)) for a scalar z."""
    if z >= 0:
        return 1.0 / (1.0 + np.exp(-z))
    ez = np.exp(z)
    return ez / (1.0 + ez)


def logisticRegression(trainDataList, trainLabelList, max_iter = 100, learning_rate = 0.001):
    """
    Train a binary logistic regression model by per-sample gradient ascent
    on the log-likelihood.

    Parameters
    ----------
    trainDataList : sequence of feature rows (each a sequence of numbers).
        It is NOT modified; the bias feature is added to an internal copy.
    trainLabelList : sequence of 0/1 labels, one per row.
    max_iter : number of passes (epochs) over the training set.
    learning_rate : step size of each per-sample update.

    Returns
    -------
    w : np.ndarray of length n_features + 1. The last entry is the bias weight,
        matching a constant 1 appended to the end of each feature row.

    Raises
    ------
    ValueError : if trainDataList is empty or its length differs from
        trainLabelList.
    """
    if len(trainDataList) == 0:
        raise ValueError("trainDataList must not be empty")
    if len(trainDataList) != len(trainLabelList):
        raise ValueError("trainDataList and trainLabelList must have the same length")

    features = np.asarray(trainDataList, dtype=float)
    # Append the bias column to a copy instead of mutating the caller's rows,
    # so re-running the training cell does not keep growing the input.
    X = np.hstack([features, np.ones((features.shape[0], 1))])
    y = np.asarray(trainLabelList, dtype=float)
    w = np.zeros(X.shape[1])

    for _ in range(max_iter):
        for xi, yi in zip(X, y):
            wx = np.dot(w, xi)
            # Per-sample log-likelihood gradient: xi * (yi - sigmoid(wx)).
            # The stable sigmoid avoids exp overflow (inf / inf = nan) that
            # would otherwise poison w for large |wx|.
            w += learning_rate * (yi - _sigmoid(wx)) * xi

    return w
               

In [4]:
# Load both splits; load_data maps digit 0 to label 1 and all other digits to 0.
TRAIN_CSV = './mnist/mnist_train.csv'
TEST_CSV = './mnist/mnist_test.csv'
trainData, trainLabel = load_data(TRAIN_CSV)
testData, testLabel = load_data(TEST_CSV)

100%|███████████████████████████████████| 60000/60000 [00:09<00:00, 6074.26it/s]
100%|███████████████████████████████████| 10000/10000 [00:02<00:00, 4297.44it/s]


In [5]:
# Fit the weights using the default number of epochs (max_iter=100).
w = logisticRegression(trainDataList=trainData, trainLabelList=trainLabel)

In [7]:
# Evaluate on the test set. The bias feature (constant 1, matching the one the
# training step appends) is added to a copy of each row, so re-running this
# cell does not keep growing testData and break np.dot's shape check.
assert len(testData) > 0, "test set is empty"

# Count misclassified samples.
error_count = 0

for features, label in zip(testData, testLabel):
    y_predict = predict(w, list(features) + [1])
    if y_predict != label:
        error_count += 1

acc = 1 - error_count / len(testData)

print("Accuracy: ", acc)

Accuracy:  0.9922
