#### SVM算法应用于MNIST数据集（PCA降维 + RBF核SVC分类）

In [6]:
# Load the MNIST train/test splits, downloading them into ./data on first run.
import torch
# `transforms` provides image preprocessing / data-augmentation steps, e.g.
#   normalization (Normalize)
#   resizing (Resize)
#   augmentation (flips, crops, rotations, ...)
#   conversion to tensors (ToTensor)
from torchvision import datasets, transforms

# Preprocessing pipeline attached to both datasets: convert each image to a
# tensor, then normalize it.
# NOTE(review): 0.1307 / 0.3081 are presumably the MNIST pixel mean / std
# (the values used in the PyTorch MNIST example) — confirm.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# download=True (was False): on a fresh machine the data is fetched instead of
# failing with "Dataset not found"; when the files already exist under `root`,
# torchvision skips the download, so re-running this cell is cheap.
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

In [7]:
# Mini-batch sizes for the training and evaluation loaders.
train_batch_size, test_batch_size = 64, 1000

# Wrap each dataset in a DataLoader. Only the training loader shuffles;
# the test loader keeps the dataset's fixed order.
trainset1 = torch.utils.data.DataLoader(
    trainset,
    batch_size=train_batch_size,
    shuffle=True,
)
testset1 = torch.utils.data.DataLoader(
    testset,
    batch_size=test_batch_size,
    shuffle=False,
)

# Report how many samples sit behind each loader (a count, not a tensor shape).
print(f"训练集大小: {len(trainset1.dataset)}")
print(f"测试集大小: {len(testset1.dataset)}")

训练集大小: 60000
测试集大小: 10000


In [11]:
# Imports for the classical-ML part of the notebook.
from sklearn.svm import SVC
from sklearn.decomposition import PCA
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
import numpy as np

# Hyperparameters (previously inline magic numbers).
SEED = 0           # passed to PCA so a randomized SVD solver, if chosen, is reproducible
N_COMPONENTS = 50  # feature dimensionality after PCA
SVM_C = 10         # regularization parameter of the RBF-kernel SVC

# Flatten every 28x28 image into a 784-long feature row.
# NOTE: `.data` / `.targets` are the raw tensors stored on the dataset objects,
# so the `transform` pipeline (ToTensor + Normalize) is NOT applied here — the
# model sees raw pixel values (presumably uint8 in [0, 255] — confirm).
# NOTE(review): this appears harmless, since PCA centres the data and
# gamma='scale' rescales the RBF width by the feature variance; re-check if the
# kernel or gamma setting changes.
X_train = trainset.data.numpy().reshape(-1, 28 * 28)
y_train = trainset.targets.numpy()
X_test = testset.data.numpy().reshape(-1, 28 * 28)
y_test = testset.targets.numpy()

# Reduce dimensionality with PCA: fit on the training set only, then apply the
# same projection to the test set (avoids leaking test statistics).
pca = PCA(n_components=N_COMPONENTS, random_state=SEED)
X_train_pca = pca.fit_transform(X_train)
X_test_pca = pca.transform(X_test)
assert X_train_pca.shape == (len(y_train), N_COMPONENTS)
assert X_test_pca.shape == (len(y_test), N_COMPONENTS)

# RBF-kernel SVM on the PCA features.
svm_clf = SVC(kernel='rbf', C=SVM_C, gamma='scale')
svm_clf.fit(X_train_pca, y_train)

# Evaluate on the held-out test set.
y_pred = svm_clf.predict(X_test_pca)
accuracy = accuracy_score(y_test, y_pred)
print(f"准确率: {accuracy:.5f}")
print(classification_report(y_test, y_pred))


准确率: 0.98680
              precision    recall  f1-score   support

           0       0.99      0.99      0.99       980
           1       0.99      1.00      0.99      1135
           2       0.98      0.98      0.98      1032
           3       0.98      0.99      0.99      1010
           4       0.99      0.98      0.99       982
           5       0.99      0.99      0.99       892
           6       0.99      0.99      0.99       958
           7       0.99      0.98      0.98      1028
           8       0.98      0.99      0.99       974
           9       0.99      0.98      0.98      1009

    accuracy                           0.99     10000
   macro avg       0.99      0.99      0.99     10000
weighted avg       0.99      0.99      0.99     10000

