In [1]:
import numpy as np


#加载数据
def load_data():
    #这个数据集,只有数字0,1,2,3,一共800行784列
    with open('mnist_cut.csv') as fr:
        lines = fr.readlines()

    x = np.empty((len(lines), 784), dtype=float)
    y = np.empty(len(lines), dtype=int)

    for i in range(len(lines)):
        line = lines[i].strip().split(',')
        x[i] = line[1:]
        y[i] = line[0]

    #归一化
    x /= 255

    return x, y


x, y = load_data()
x[:5], y[:5], x.shape, y.shape

(array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]]),
 array([2, 1, 0, 1, 0]),
 (800, 784),
 (800,))

In [2]:
#查看各类的数量
for i in range(4):
    print(i, (y == [i]).sum())

0 168
1 225
2 209
3 198


In [3]:
N, M = x.shape

#这里训练的是ovr的多分类,一共4个二分类器
params = [0, 1, 2, 3]
ws = {}
bs = {}
for i in params:
    ws[i] = np.empty(M)
    ws[i].fill(1 / M)
    bs[i] = 0.0

In [4]:
#预测函数,显然,他一次只能做一次二分类判断
def predict(params, x):
    #print('predict,params=', params)
    z = ws[params].dot(x) + bs[params]
    return 1 / (1 + np.exp(-z))


predict(2, x[0])

0.5360144532148693

In [5]:
#这个log函数是为了避免log0.
def log(p):
    if p < 1e-20:
        p = 1e-20
    return np.log(p)


#loss函数
def get_loss(param):
    loss = 0
    for i in range(N):
        p = predict(param, x[i])
        #等于param的类是0，否则类是1
        d = 0 if y[i] == param else 1
        loss += d * log(p) + (1 - d) * log(1 - p)

    return loss


get_loss(2)

-537.0785022432013

In [6]:
#暴力求梯度法
#同样的,因为这个程序中实际是4个二分类器,所以这里的梯度,也要指定是哪一个的
def get_gradient(param):
    global w
    global b
    upsilon = 1e-5

    gradient_w = np.empty(M)

    for i in range(M):
        l1 = get_loss(param)
        ws[param][i] += upsilon
        l2 = get_loss(param)
        ws[param][i] -= upsilon
        gradient_w[i] = (l2 - l1) / upsilon

    l1 = get_loss(param)
    bs[param] += upsilon
    l2 = get_loss(param)
    bs[param] -= upsilon
    gradient_b = (l2 - l1) / upsilon

    return gradient_w, gradient_b


get_gradient(2)[0][:100]

array([ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  

In [7]:
#训练,因为是4个分类器,显然要训练4次
for param in params:
    for i in range(50):
        g_w, g_b = get_gradient(param)
        ws[param] += g_w * 1e-3
        bs[param] += g_b * 1e-3

        if i % 10 == 0:
            print(param, i, get_loss(param))

0 0 -632.6963241739369
0 10 -38.981318928687415
0 20 -28.588793998481414
0 30 -23.014555327394415
0 40 -19.404857433677375
1 0 -580.948153246214
1 10 -43.96444525860368
1 20 -33.85263049684668
1 30 -28.010717257119936
1 40 -24.101525607073505
2 0 -813.0594134266456
2 10 -116.88068344288085
2 20 -101.37078871602604
2 30 -91.96638887264375
2 40 -85.08306111384184
3 0 -849.5871861168231
3 10 -82.83738995425851
3 20 -68.55090861007622
3 30 -60.03544809177709
3 40 -54.05138493660565


In [8]:
#最终结果由4个分类器投票表决，如果出现平局可以考虑每个分类器的置信度
def vote(x):
    for param in params:
        p = predict(param, x)
        if p < 0.5:
            return param
    return 0


vote(x[0]), y[0]

(2, 2)

In [9]:
#测试
correct = 0
for i in range(N):
    p = vote(x[i])
    if p == y[i]:
        correct += 1

print(correct / N)

0.96
