In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F

class FCN(nn.Module):
    """Fully convolutional classifier for single-channel images.

    Five 3x3 convolutions (stride 1, padding 1, so the spatial size is kept)
    are followed by global average pooling. The last convolution emits one
    feature map per class, so the pooled result is directly the class logits;
    no fully connected layer is involved.

    Args:
        num_classes: Number of output channels of the final convolution,
            i.e. the number of logits returned per sample.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Feature extractor: channel width doubles at every stage.
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
        # Classification head placed before the global average pooling:
        # its output channel count equals the number of classes.
        self.conv5 = nn.Conv2d(256, num_classes, 3, padding=1)

    def forward(self, x):
        """Return logits of shape (B, num_classes) for an input of shape (B, 1, H, W)."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.relu(conv(x))

        # The final convolution is deliberately left without a ReLU so the
        # pooled logits can take negative values.
        class_maps = self.conv5(x)

        # Global average pooling collapses every class map to a single value:
        # (B, C, H, W) -> (B, C, 1, 1).
        pooled = F.adaptive_avg_pool2d(class_maps, (1, 1))

        # Drop the two singleton spatial dimensions so the loss receives (B, C).
        return pooled.view(pooled.size(0), -1)

# 定义训练函数
def train(model, device, data_loader, optimizer, criterion, epoch, log_interval=100):
    """Run one training pass over ``data_loader`` and return the mean batch loss.

    Args:
        model: Network to optimise; it is switched to training mode.
        device: Device every batch is moved to before the forward pass.
        data_loader: Iterable yielding ``(data, target)`` batches. ``len()`` is
            only called on it when a progress line is printed.
        optimizer: Optimizer stepped once per batch.
        criterion: Callable ``criterion(output, target)`` returning a scalar loss.
        epoch: Zero-based epoch index; printed as ``epoch + 1``.
        log_interval: Print the average loss of the last ``log_interval``
            batches every ``log_interval`` batches. ``None`` or a non-positive
            value disables the progress output.

    Returns:
        The mean of the per-batch loss values seen during this pass, or ``0.0``
        when the loader yields no batches.
    """
    model.train()
    logging_enabled = log_interval is not None and log_interval > 0

    epoch_loss = 0.0   # sum of all batch losses in this pass
    window_loss = 0.0  # sum of batch losses since the last progress line
    batches_seen = 0
    for batch_idx, (data, target) in enumerate(data_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        epoch_loss += batch_loss
        window_loss += batch_loss
        batches_seen += 1

        # Report once per full window; a trailing partial window is not printed.
        if logging_enabled and (batch_idx + 1) % log_interval == 0:
            print(f'Epoch {epoch + 1}, Batch [{batch_idx + 1}/{len(data_loader)}], AvgLoss: {window_loss / log_interval:.4f}')
            window_loss = 0.0

    return epoch_loss / batches_seen if batches_seen else 0.0

# 定义验证函数
def evaluate(model, device, data_loader, criterion):
    """Measure accuracy and mean batch loss of ``model`` on ``data_loader``.

    The model is put into evaluation mode and gradients are disabled for the
    whole pass.

    Args:
        model: Network to evaluate; its output is arg-maxed over dimension 1.
        device: Device every batch is moved to before the forward pass.
        data_loader: Iterable yielding ``(data, target)`` batches. It does not
            need a ``dataset`` attribute.
        criterion: Callable ``criterion(output, target)`` returning a scalar loss.

    Returns:
        A tuple ``(accuracy, loss)``. ``accuracy`` is the fraction of correct
        predictions among the predictions actually scored, so it stays correct
        when the loader visits only part of its dataset. ``loss`` is the mean
        of the per-batch criterion values (not weighted by batch size). Both
        are ``0.0`` when the loader yields no batches.
    """
    model.eval()

    loss_sum = 0.0
    num_batches = 0
    num_correct = 0
    num_scored = 0
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)

            loss_sum += criterion(output, target).item()  # accumulate batch losses
            num_batches += 1

            pred = output.argmax(dim=1, keepdim=True)
            num_correct += pred.eq(target.view_as(pred)).sum().item()
            # Count what was really evaluated instead of trusting the dataset
            # length, which differs under drop_last or a custom sampler.
            num_scored += pred.numel()

    accuracy = num_correct / num_scored if num_scored else 0.0
    loss = loss_sum / num_batches if num_batches else 0.0
    return accuracy, loss

def main():
    """Train the FCN on MNIST and export the result as ``.pth`` and ``.onnx``.

    Workflow:
      1. Pick CUDA when available, otherwise the CPU.
      2. Build the model and resume from ``mnist_fcn.pth`` when that file exists.
      3. Train for ``num_epochs`` epochs, printing accuracy and loss on the
         training and the test split after each epoch.
      4. Save the ``state_dict`` and export an ONNX graph traced with a single
         1x1x28x28 dummy input.
    """
    pth_file_path = 'mnist_fcn.pth'
    onnx_file_path = "mnist_fcn.onnx"

    # Select the device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f'Using device: {device}')

    # Instantiate the model and resume from a previous run if possible. The
    # checkpoint is only written at the end of this function, so it does not
    # exist on the very first run and must not be required.
    model = FCN().to(device)
    try:
        # map_location lets a checkpoint saved on another device (e.g. CUDA)
        # be restored on the device chosen above.
        checkpoint = torch.load(pth_file_path, map_location=device)
    except FileNotFoundError:
        print(f'No checkpoint found at {pth_file_path}, training from scratch')
    else:
        model.load_state_dict(checkpoint)

    # Load the MNIST dataset.
    transform = transforms.Compose([
        transforms.ToTensor(),
        # NOTE(review): presumably the MNIST training-set mean/std - confirm.
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

    validate_dataset = datasets.MNIST('./data', train=False, transform=transform)
    # Evaluation does not benefit from shuffling; a fixed order keeps the
    # reported validation loss reproducible between runs.
    validate_loader = DataLoader(validate_dataset, batch_size=64, shuffle=False)

    # Loss function and optimizer.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Training and validation loop.
    num_epochs = 2
    for epoch in range(num_epochs):
        train(model, device, train_loader, optimizer, criterion, epoch)
        train_accuracy, train_loss = evaluate(model, device, train_loader, criterion)
        validate_accuracy, validate_loss = evaluate(model, device, validate_loader, criterion)
        print(f'Epoch {epoch + 1}, train_accuracy = {train_accuracy:.4f}, validate_accuracy = {validate_accuracy:.4f}')
        print(f'Epoch {epoch + 1}, train_loss = {train_loss:.4f}, validate_loss = {validate_loss:.4f}')

    # Save the trained parameters.
    torch.save(model.state_dict(), pth_file_path)

    # Export to ONNX; tracing needs an example input with the model's input layout.
    input_shape = (1, 1, 28, 28)  # MNIST images are 28x28 pixels with one channel
    dummy_input = torch.randn(input_shape).to(device)
    torch.onnx.export(model, dummy_input, onnx_file_path, export_params=True, opset_version=17, do_constant_folding=True)

    print(f"Train finished, Export Model to {pth_file_path} and {onnx_file_path}")

# Script entry point: run the training pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()

Using device: cpu
Epoch 1, Batch [100/938], AvgLoss: 1.8211
Epoch 1, Batch [200/938], AvgLoss: 0.9904
Epoch 1, Batch [300/938], AvgLoss: 0.5903
Epoch 1, Batch [400/938], AvgLoss: 0.4752
Epoch 1, Batch [500/938], AvgLoss: 0.3605
Epoch 1, Batch [600/938], AvgLoss: 0.2953
Epoch 1, Batch [700/938], AvgLoss: 0.2678
Epoch 1, Batch [800/938], AvgLoss: 0.2176
Epoch 1, Batch [900/938], AvgLoss: 0.2156
Epoch 1, train_accuracy = 0.9429, validate_accuracy = 0.9428
Epoch 1, train_loss = 0.1894, validate_loss = 0.1819
Epoch 2, Batch [100/938], AvgLoss: 0.1883
Epoch 2, Batch [200/938], AvgLoss: 0.1713
Epoch 2, Batch [300/938], AvgLoss: 0.1549
Epoch 2, Batch [400/938], AvgLoss: 0.1241
Epoch 2, Batch [500/938], AvgLoss: 0.1644
Epoch 2, Batch [600/938], AvgLoss: 0.1499
Epoch 2, Batch [700/938], AvgLoss: 0.1090
Epoch 2, Batch [800/938], AvgLoss: 0.1249
Epoch 2, Batch [900/938], AvgLoss: 0.1154
Epoch 2, train_accuracy = 0.9615, validate_accuracy = 0.9607
Epoch 2, train_loss = 0.1316, validate_loss = 0.122