# 数据集：Mnist
- 训练集数量：60000
- 测试集数量：10000

In [1]:
import numpy as np
import math
import random
from tqdm import tqdm

In [2]:
# load data
def load_data(fileName):
    """
    加载数据
    """
    
    dataArr = []
    labelArr = []
    
    fr = open(fileName, 'r')
    
    for line in tqdm(fr.readlines()):
        currentLine = line.strip().split(',')
        dataArr.append([int(num) /255 for num in currentLine[1: ]])
        
        # 0 -> 1 其余 -> -1
        if int(currentLine[0]) == 0:
            labelArr.append(1)
        else:
            labelArr.append(-1)
            
    return dataArr, labelArr

In [6]:
# SVM
class SVM:
    """
    SVM 类
    """
    
    def __init__(self, trainDataList, trainLabelList, sigma = 10, C = 200, tol = 0.001):
        """
        SVM相关参数初始化
        sigma : 高斯核中的分母sigma
        C : 软间隔中的惩罚参数
        tol ：松弛变量
        """
        
        self.trainDataMat = np.mat(trainDataList)
        self.trainLabelMat = np.mat(trainLabelList).T
        
        self.m, self.n = np.shape(self.trainDataMat)  # m 样本数量 n 特征个数
        self.sigma = sigma
        self.C = C
        self.tol = tol
        
        self.k = self.calcKernel()  # 核函数
        self.b = 0
        self.alpha = [0] * self.trainDataMat.shape[0]
        self.E = [0 * self.trainLabelMat[i, 0] for i in range(self.trainLabelMat.shape[0])]  #SMO 运算过程中的Ei
        self.supportVecIndex = []
        
        
    def calcKernel(self):
        """
        计算核函数 高斯核函数
        """
        
        # 初始化高斯矩阵 k[i][j] = Xi * Xj
        k = [[0 for i in range(self.m)] for j in range(self.m)]
        
        for i in range(self.m):
            if i % 100 == 0:
                print("construct the kernel: ", i, self.m)
            
            X = self.trainDataMat[i, :]
            #
            for j in range(i, self.m):
                Z = self.trainDataMat[j, :]
                result = (X - Z) * (X - Z).T
                result = np.exp(-1 * result / (2 * self.sigma ** 2))
                k[i][j] = result
                k[j][i] = result
                
        return k
    
    
    def isSatifyKKT(self, i):
        """
        查看第i个alpha是否满足KKT条件
        """
        
        gxi = self.calc_gxi(i)
        yi = self.trainLabelMat[i]
        
        if (math.fabs(self.alpha[i]) < self.tol) and (yi * gxi >= 1):
            return True
        elif (math.fabs(self.alpha[i] - self.C) < self.tol) and (yi * gxi <= 1):
            return True
        elif (self.alpha[i] > -self.tol) and (self.alpha[i] < (self.C + self.tol)) and (math.fabs(yi * gxi - 1) < self.tol):
            return True
        
        return False
    
    def calc_gxi(self, i):
        """
        计算g(xi)
        """
        
        gxi = 0
        index = [i for i, alpha in enumerate(self.alpha) if alpha != 0]
        
        for j in index:
            gxi += self.alpha[j] * self.trainLabelMat[j] * self.k[j][i]
        
        gxi += self.b
        
        return gxi
    
    def calc_Ei(self, i):
        """
        计算Ei
        """
        gxi = self.calc_gxi(i)
        return gxi - self.trainLabelMat[i]
        
    def getAlphaJ(self, E1, i):
        """
        SMO 选择第二个变量
        """
        
        E2 = 0
        
        maxE1_E2 = -1
        maxIndex = -1
        
        nozeroE = [i for i, Ei in enumerate(self.E) if Ei != 0]
        
        for j in nozeroE:
            E2_tmp = self.calc_Ei(j)
            if math.fabs(E1 - E2_tmp) > maxE1_E2:
                maxE1_E2 = math.fabs(E1 - E2_tmp)
                E2 = E2_tmp
                maxIndex = j
        
        if maxIndex == -1:
            maxIndex = i
            while maxIndex == i:
                maxIndex = int(random.uniform(0, self.m))
            E2 = self.calc_Ei(maxIndex)
            
        return E2, maxIndex
    
    def train(self, max_iter = 100):
        """
        训练
        """
        iterStep = 0
        paramterChanged = 1
        
        while (iterStep < max_iter) and (paramterChanged > 0):
            print('iter: %d: %d' % (iterStep, max_iter))
            iterStep += 1
            paramterChanged = 0
            
            for i in range(self.m):
                if self.isSatifyKKT(i) == False:
                    E1 = self.calc_Ei(i)
                    
                    E2, j = self.getAlphaJ(E1, i)
                    
                    y1 = self.trainLabelMat[i]
                    y2 = self.trainLabelMat[j]
                    
                    alpha0ld_1 = self.alpha[i]
                    alpha0ld_2 = self.alpha[j]
                    
                    if y1 != y2:
                        L = max(0, alpha0ld_2 - alpha0ld_1)
                        H = min(self.C, self.C + alpha0ld_2 - alpha0ld_1)
                    else:
                        L = max(0, alpha0ld_2 + alpha0ld_1 - self.C)
                        H = min(self.C, alpha0ld_2 + alpha0ld_1)
                        
                    if L == H:
                        continue
                        
                    k11 = self.k[i][i]
                    k12 = self.k[i][j]
                    k21 = self.k[j][i]
                    k22 = self.k[j][j]
                    
                    alphaNew_2 = alpha0ld_2 + y2 * (E1 - E2) / (k11 + k22 - 2 * k12)
                    
                    if alphaNew_2 < L :
                        alphaNew_2 = L
                    elif alphaNew_2 > H:
                        alphaNew_2 = H
                        
                    alphaNew_1 = alpha0ld_1 + y1 * y2 * (alpha0ld_2 - alphaNew_2)
                    
                    b1New = -1 * E1 - y1 * k11 * (alphaNew_1 - alpha0ld_1) - y2 * k21 * (alphaNew_2 - alpha0ld_2) + self.b
                    b2New = -1 * E2 - y1 * k12 * (alphaNew_1 - alpha0ld_1) - y2 * k22 * (alphaNew_2 - alpha0ld_2) + self.b
                    
                    if (alphaNew_1 > 0) and (alphaNew_1 < self.C):
                        bNew = b1New
                    elif (alphaNew_2 > 0) and (alphaNew_2 < self.C):
                        bNew = b2New
                    else:
                        bNew = (b1New + b2New) / 2
                    
                    # Update
                    self.alpha[i] = alphaNew_1
                    self.alpha[j] = alphaNew_2
                    self.b = bNew
                    
                    self.E[i] = self.calc_Ei(i)
                    self.E[j] = self.calc_Ei(j)
                    
                    if math.fabs(alphaNew_2 - alpha0ld_2) >= 0.00001:
                        paramterChanged += 1
                        
                print("iter: %d i : %d, pairs chaged %d " % (iterStep, i, paramterChanged))
            
        for i in range(self.m):
            if self.alpha[i] > 0:
                self.supportVecIndex.append(i)
                
    def calc_singleKernal(self, x1, x2):
        """
        单独计算核函数
        """
        result = (x1 - x2) * (x1 - x2).T
        result = np.exp(-1 * result / (2 * self.sigma ** 2))
        
        return np.exp(result)
    
    def predict(self, x):
        """
        预测
        """
        result = 0
        for i in self.supportVecIndex:
            tmp = self.calc_singleKernal(self.trainDataMat[i, :], np.mat(x))
            result += self.alpha[i] * self.trainLabelMat[i] * tmp
            
        result += self.b
        return np.sign(result)
    
    def test(self, testDataList, testLabelList):
        """
        测试
        """
        
        error_count = 0
        y_pred = []
        for i in range(len(testDataList)):
            result = self.predict(testDataList[i])
            y_pred.append(result)
            if result != testLabelList[i]:
                error_count += 1
        
        acc = 1 - error_count / len(testDataList)
        return acc, y_pred
    
    

In [4]:
trainDataList, trainLabelList = load_data('./mnist/mnist_train.csv')
testDataList, testLabelList = load_data('./mnist/mnist_test.csv')

100%|███████████████████████████████████| 60000/60000 [00:09<00:00, 6377.54it/s]
100%|███████████████████████████████████| 10000/10000 [00:01<00:00, 6247.87it/s]


In [7]:
svm = SVM(trainDataList[: 1000], trainLabelList[: 1000], 10, 200, 0.001)
svm.train()

construct the kernel:  0 1000
construct the kernel:  100 1000
construct the kernel:  200 1000
construct the kernel:  300 1000
construct the kernel:  400 1000
construct the kernel:  500 1000
construct the kernel:  600 1000
construct the kernel:  700 1000
construct the kernel:  800 1000
construct the kernel:  900 1000
iter: 0: 100
iter: 1 i : 0, pairs chaged 1 
iter: 1 i : 1, pairs chaged 2 
iter: 1 i : 2, pairs chaged 3 
iter: 1 i : 3, pairs chaged 4 
iter: 1 i : 4, pairs chaged 5 
iter: 1 i : 5, pairs chaged 6 
iter: 1 i : 6, pairs chaged 6 
iter: 1 i : 7, pairs chaged 7 
iter: 1 i : 8, pairs chaged 7 
iter: 1 i : 9, pairs chaged 7 
iter: 1 i : 10, pairs chaged 8 
iter: 1 i : 11, pairs chaged 9 
iter: 1 i : 14, pairs chaged 9 
iter: 1 i : 17, pairs chaged 9 
iter: 1 i : 18, pairs chaged 9 
iter: 1 i : 19, pairs chaged 9 
iter: 1 i : 20, pairs chaged 9 
iter: 1 i : 21, pairs chaged 10 
iter: 1 i : 22, pairs chaged 10 
iter: 1 i : 23, pairs chaged 10 
iter: 1 i : 24, pairs chaged 11 
ite

iter: 1 i : 416, pairs chaged 35 
iter: 1 i : 417, pairs chaged 35 
iter: 1 i : 418, pairs chaged 35 
iter: 1 i : 419, pairs chaged 35 
iter: 1 i : 420, pairs chaged 35 
iter: 1 i : 422, pairs chaged 35 
iter: 1 i : 423, pairs chaged 35 
iter: 1 i : 424, pairs chaged 35 
iter: 1 i : 427, pairs chaged 35 
iter: 1 i : 428, pairs chaged 35 
iter: 1 i : 429, pairs chaged 35 
iter: 1 i : 431, pairs chaged 35 
iter: 1 i : 434, pairs chaged 35 
iter: 1 i : 435, pairs chaged 35 
iter: 1 i : 437, pairs chaged 35 
iter: 1 i : 438, pairs chaged 35 
iter: 1 i : 440, pairs chaged 35 
iter: 1 i : 441, pairs chaged 35 
iter: 1 i : 442, pairs chaged 35 
iter: 1 i : 443, pairs chaged 35 
iter: 1 i : 444, pairs chaged 35 
iter: 1 i : 447, pairs chaged 35 
iter: 1 i : 450, pairs chaged 35 
iter: 1 i : 454, pairs chaged 35 
iter: 1 i : 455, pairs chaged 35 
iter: 1 i : 458, pairs chaged 35 
iter: 1 i : 459, pairs chaged 35 
iter: 1 i : 462, pairs chaged 35 
iter: 1 i : 463, pairs chaged 35 
iter: 1 i : 46

iter: 1 i : 876, pairs chaged 35 
iter: 1 i : 878, pairs chaged 35 
iter: 1 i : 879, pairs chaged 35 
iter: 1 i : 880, pairs chaged 35 
iter: 1 i : 881, pairs chaged 35 
iter: 1 i : 884, pairs chaged 35 
iter: 1 i : 888, pairs chaged 35 
iter: 1 i : 891, pairs chaged 35 
iter: 1 i : 892, pairs chaged 35 
iter: 1 i : 894, pairs chaged 35 
iter: 1 i : 898, pairs chaged 35 
iter: 1 i : 899, pairs chaged 35 
iter: 1 i : 900, pairs chaged 35 
iter: 1 i : 904, pairs chaged 35 
iter: 1 i : 905, pairs chaged 35 
iter: 1 i : 915, pairs chaged 35 
iter: 1 i : 918, pairs chaged 35 
iter: 1 i : 919, pairs chaged 35 
iter: 1 i : 920, pairs chaged 35 
iter: 1 i : 921, pairs chaged 35 
iter: 1 i : 922, pairs chaged 35 
iter: 1 i : 932, pairs chaged 35 
iter: 1 i : 933, pairs chaged 35 
iter: 1 i : 935, pairs chaged 35 
iter: 1 i : 936, pairs chaged 35 
iter: 1 i : 939, pairs chaged 35 
iter: 1 i : 940, pairs chaged 35 
iter: 1 i : 941, pairs chaged 35 
iter: 1 i : 945, pairs chaged 35 
iter: 1 i : 94

iter: 2 i : 303, pairs chaged 2 
iter: 2 i : 304, pairs chaged 2 
iter: 2 i : 305, pairs chaged 2 
iter: 2 i : 307, pairs chaged 2 
iter: 2 i : 309, pairs chaged 2 
iter: 2 i : 310, pairs chaged 2 
iter: 2 i : 312, pairs chaged 2 
iter: 2 i : 313, pairs chaged 2 
iter: 2 i : 315, pairs chaged 2 
iter: 2 i : 318, pairs chaged 2 
iter: 2 i : 319, pairs chaged 2 
iter: 2 i : 321, pairs chaged 2 
iter: 2 i : 322, pairs chaged 2 
iter: 2 i : 326, pairs chaged 2 
iter: 2 i : 327, pairs chaged 2 
iter: 2 i : 334, pairs chaged 2 
iter: 2 i : 336, pairs chaged 2 
iter: 2 i : 338, pairs chaged 2 
iter: 2 i : 341, pairs chaged 2 
iter: 2 i : 342, pairs chaged 2 
iter: 2 i : 343, pairs chaged 2 
iter: 2 i : 344, pairs chaged 2 
iter: 2 i : 345, pairs chaged 2 
iter: 2 i : 346, pairs chaged 2 
iter: 2 i : 348, pairs chaged 2 
iter: 2 i : 350, pairs chaged 2 
iter: 2 i : 351, pairs chaged 2 
iter: 2 i : 353, pairs chaged 2 
iter: 2 i : 354, pairs chaged 2 
iter: 2 i : 355, pairs chaged 2 
iter: 2 i 

iter: 2 i : 754, pairs chaged 2 
iter: 2 i : 761, pairs chaged 2 
iter: 2 i : 762, pairs chaged 2 
iter: 2 i : 765, pairs chaged 2 
iter: 2 i : 766, pairs chaged 2 
iter: 2 i : 773, pairs chaged 2 
iter: 2 i : 775, pairs chaged 2 
iter: 2 i : 778, pairs chaged 2 
iter: 2 i : 779, pairs chaged 2 
iter: 2 i : 780, pairs chaged 2 
iter: 2 i : 781, pairs chaged 2 
iter: 2 i : 783, pairs chaged 2 
iter: 2 i : 784, pairs chaged 2 
iter: 2 i : 786, pairs chaged 2 
iter: 2 i : 787, pairs chaged 2 
iter: 2 i : 788, pairs chaged 2 
iter: 2 i : 795, pairs chaged 2 
iter: 2 i : 796, pairs chaged 2 
iter: 2 i : 797, pairs chaged 2 
iter: 2 i : 798, pairs chaged 2 
iter: 2 i : 800, pairs chaged 2 
iter: 2 i : 802, pairs chaged 2 
iter: 2 i : 803, pairs chaged 2 
iter: 2 i : 806, pairs chaged 2 
iter: 2 i : 808, pairs chaged 2 
iter: 2 i : 809, pairs chaged 2 
iter: 2 i : 810, pairs chaged 2 
iter: 2 i : 811, pairs chaged 2 
iter: 2 i : 812, pairs chaged 2 
iter: 2 i : 817, pairs chaged 2 
iter: 2 i 

iter: 3 i : 217, pairs chaged 0 
iter: 3 i : 218, pairs chaged 0 
iter: 3 i : 219, pairs chaged 0 
iter: 3 i : 221, pairs chaged 0 
iter: 3 i : 223, pairs chaged 0 
iter: 3 i : 224, pairs chaged 0 
iter: 3 i : 225, pairs chaged 0 
iter: 3 i : 226, pairs chaged 0 
iter: 3 i : 227, pairs chaged 0 
iter: 3 i : 228, pairs chaged 0 
iter: 3 i : 229, pairs chaged 0 
iter: 3 i : 230, pairs chaged 0 
iter: 3 i : 231, pairs chaged 0 
iter: 3 i : 233, pairs chaged 0 
iter: 3 i : 235, pairs chaged 0 
iter: 3 i : 238, pairs chaged 0 
iter: 3 i : 240, pairs chaged 0 
iter: 3 i : 247, pairs chaged 0 
iter: 3 i : 248, pairs chaged 0 
iter: 3 i : 249, pairs chaged 0 
iter: 3 i : 250, pairs chaged 0 
iter: 3 i : 251, pairs chaged 0 
iter: 3 i : 252, pairs chaged 0 
iter: 3 i : 255, pairs chaged 0 
iter: 3 i : 256, pairs chaged 0 
iter: 3 i : 257, pairs chaged 0 
iter: 3 i : 258, pairs chaged 0 
iter: 3 i : 264, pairs chaged 0 
iter: 3 i : 267, pairs chaged 0 
iter: 3 i : 268, pairs chaged 0 
iter: 3 i 

iter: 3 i : 648, pairs chaged 0 
iter: 3 i : 650, pairs chaged 0 
iter: 3 i : 653, pairs chaged 0 
iter: 3 i : 654, pairs chaged 0 
iter: 3 i : 659, pairs chaged 0 
iter: 3 i : 660, pairs chaged 0 
iter: 3 i : 661, pairs chaged 0 
iter: 3 i : 664, pairs chaged 0 
iter: 3 i : 669, pairs chaged 0 
iter: 3 i : 671, pairs chaged 0 
iter: 3 i : 674, pairs chaged 0 
iter: 3 i : 676, pairs chaged 0 
iter: 3 i : 678, pairs chaged 0 
iter: 3 i : 682, pairs chaged 0 
iter: 3 i : 683, pairs chaged 0 
iter: 3 i : 685, pairs chaged 0 
iter: 3 i : 686, pairs chaged 0 
iter: 3 i : 688, pairs chaged 0 
iter: 3 i : 689, pairs chaged 0 
iter: 3 i : 691, pairs chaged 0 
iter: 3 i : 692, pairs chaged 0 
iter: 3 i : 694, pairs chaged 0 
iter: 3 i : 695, pairs chaged 0 
iter: 3 i : 696, pairs chaged 0 
iter: 3 i : 698, pairs chaged 0 
iter: 3 i : 700, pairs chaged 0 
iter: 3 i : 705, pairs chaged 0 
iter: 3 i : 706, pairs chaged 0 
iter: 3 i : 707, pairs chaged 0 
iter: 3 i : 708, pairs chaged 0 
iter: 3 i 

In [8]:
# test
acc, y_pre = svm.test(testDataList[0: 100], testLabelList[0: 100])
print("Accuracy: ", acc)

Accuracy:  0.98


In [10]:
trainLabelList[0: 2]

[-1, 1]

In [11]:
y_pre[:2]

[matrix([[-1.]]), matrix([[-1.]])]