In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from PIL import Image
import matplotlib.pyplot as plt

# 设置随机种子
torch.manual_seed(42)



<torch._C.Generator at 0x1a03d659930>

In [8]:
# 1. 定义数据预处理
transform = transforms.Compose([
    transforms.Grayscale(),  # 确保图像是灰度图
    transforms.Resize((28, 28)),  # 调整到28x28
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# 3. 加载MNIST数据集（用于训练/测试）
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, 
                                         transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, 
                                        transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)


In [9]:
# 2. 定义CNN模型（与之前相同）
class DigitCNN(nn.Module):
    def __init__(self):
        super(DigitCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

In [10]:
# 4. 初始化模型、损失函数和优化器
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DigitCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
model = DigitCNN().to(device)
model.load_state_dict(torch.load('digit_cnn.pth', map_location=device))

  model.load_state_dict(torch.load('digit_cnn.pth', map_location=device))


<All keys matched successfully>

In [11]:
# 6. 测试模型（测试集）
def test_model():
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    return accuracy

# 5. 训练模型并保存
def train_model(num_epochs=5, save_path='digit_cnn.pth'):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')
        running_loss = 0.0
        print(f'Accuracy after epoch {epoch+1}: {test_model():.2f}%')
    # 保存模型
    torch.save(model.state_dict(), save_path)
    print(f'Model saved to {save_path}')



# 7. 测试单张图片
def predict_image(image_path, model_path='digit_cnn.pth'):
    # 加载模型
    model = DigitCNN().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # 加载和预处理图片
    image = Image.open(image_path)
    image = transform(image).unsqueeze(0).to(device)  # 添加batch维度

    # 预测
    with torch.no_grad():
        output = model(image)
        _, predicted = torch.max(output, 1)
        prediction = predicted.item()

    # 显示图片和预测结果
    plt.imshow(image.squeeze().cpu().numpy(), cmap='gray')
    plt.title(f'Predicted Digit: {prediction}')
    plt.axis('off')
    plt.show()
    print(f'Predicted Digit: {prediction}')



In [12]:
print("Training started...")
train_model(num_epochs=5, save_path='digit_cnn.pth')
print("\nTesting started...")
test_model()



Training started...
Epoch [1/5], Loss: 0.0401
Accuracy after epoch 1: 99.26%
Epoch [2/5], Loss: 0.0045
Accuracy after epoch 2: 99.26%
Epoch [3/5], Loss: 0.0045
Accuracy after epoch 3: 99.26%
Epoch [4/5], Loss: 0.0045
Accuracy after epoch 4: 99.26%
Epoch [5/5], Loss: 0.0045
Accuracy after epoch 5: 99.26%
Model saved to digit_cnn.pth

Testing started...


99.26