In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import time
import os
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
import numpy as np
import torch.nn.functional as F

# CNN model definition
class CNNNet(nn.Module):
    """Two conv/ReLU/max-pool stages followed by a two-layer fully connected head.

    The two stride-2 max-pools reduce a 28x28 single-channel input to 7x7, which
    is the spatial size ``fc1``'s ``7*7*64`` input features are built for.
    ``forward`` returns raw logits of shape ``(batch, 10)``.
    """

    def __init__(self):
        super(CNNNet, self).__init__()
        # 5x5 kernels with padding=2 keep H and W unchanged; only the pools downsample.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2)
        self.fc1 = nn.Linear(7*7*64, 1024)
        self.fc2 = nn.Linear(1024, 10)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        # Flatten per sample while keeping the batch dimension fixed. A
        # ``view(-1, 7*7*64)`` would silently regroup a wrongly sized input into
        # a different batch size; with flatten, fc1 raises a shape error instead.
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)

class FCN(nn.Module):
    """Fully convolutional classifier with no Linear layers.

    Three conv -> ReLU -> 2x2 max-pool stages feed a final conv that emits one
    map per class (10 channels). Global average pooling over each map gives
    logits of shape ``(batch, 10)``.
    """

    def __init__(self):
        super().__init__()
        # Creation order matches parameter registration order (and seeded init).
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=10, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        # Each stage halves H and W (floor division) through the shared max-pool.
        for stage_conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(stage_conv(x)))
        class_maps = self.conv4(x)
        # Global average pooling: collapse each class map's spatial dimensions.
        return class_maps.mean(dim=(2, 3))



# Training function
def train(model, device, writer_epoch: "SummaryWriter", data_loader, optimizer, criterion, epoch,
          log_interval: int = 100):
    """Run one training epoch of ``model`` over ``data_loader``.

    Every ``log_interval`` batches, the mean loss over that window of batches is
    printed and written to ``writer_epoch`` under the ``'training loss'`` tag.
    The global step is the 1-based batch index within this epoch. Batches left
    over after the last full window are not logged.

    Args:
        model: network to train; switched to train mode here.
        device: device each ``(data, target)`` batch is moved to.
        writer_epoch: TensorBoard ``SummaryWriter``. Only ``add_scalar`` is used.
        data_loader: iterable of ``(data, target)`` batches that supports ``len``.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        criterion: loss callable ``criterion(output, target)`` returning a scalar tensor.
        epoch: 0-based epoch index; shown as ``epoch + 1`` in the printout.
        log_interval: number of batches per logged loss window (default 100).

    Raises:
        ValueError: if ``log_interval`` is smaller than 1.
    """
    if log_interval < 1:
        raise ValueError(f"log_interval must be >= 1, got {log_interval}")

    model.train()
    num_batches = len(data_loader)
    running_loss = 0.0
    for batch_idx, (data, target) in enumerate(data_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # Log once every log_interval batches (after batches log_interval, 2*log_interval, ...).
        if (batch_idx + 1) % log_interval == 0:
            window_loss = running_loss / log_interval
            writer_epoch.add_scalar('training loss', window_loss, batch_idx + 1)
            print(f'Epoch {epoch + 1}, Batch [{batch_idx + 1}/{num_batches}], AvgLoss: {window_loss:.4f}')
            running_loss = 0.0

# Evaluation function
def evaluate(model, device, data_loader, criterion):
    """Compute the accuracy and mean per-sample loss of ``model`` over ``data_loader``.

    Assumes ``criterion`` returns the *mean* loss of a batch, as with the default
    ``reduction='mean'`` of ``nn.CrossEntropyLoss``. Each batch's loss is
    re-weighted by its size, so a short final batch does not skew the average.
    Both metrics are normalized by the number of samples actually seen. This
    stays correct with ``drop_last`` or a custom sampler, where that number
    differs from ``len(data_loader.dataset)``.

    Args:
        model: network to evaluate; switched to eval mode here.
        device: device each ``(data, target)`` batch is moved to.
        data_loader: iterable of ``(data, target)`` batches with class-index targets.
        criterion: loss callable ``criterion(output, target)`` returning a scalar tensor.

    Returns:
        ``(accuracy, loss)`` as Python floats.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct_count = 0
    sample_count = 0
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            batch_size = target.size(0)
            # Undo the batch-mean so the final average is per sample.
            total_loss += criterion(output, target).item() * batch_size
            pred = output.argmax(dim=1)
            correct_count += (pred == target).sum().item()
            sample_count += batch_size
    if sample_count == 0:
        raise ValueError("evaluate() received a data_loader that yielded no samples")
    accuracy = correct_count / sample_count
    loss = total_loss / sample_count
    return accuracy, loss

def _build_model(model_name, device):
    """Instantiate the model named ``model_name`` on ``device``; raise ValueError if unknown."""
    model_classes = {'CNNNet': CNNNet, 'FCN': FCN}
    if model_name not in model_classes:
        raise ValueError(f"Model {model_name} is not recognized.")
    model = model_classes[model_name]().to(device)
    print(f'Using {model_name}')
    return model


def _log_grad_cam(model, data_loader, device, writer, epoch):
    """Write the first image of ``data_loader`` and its Grad-CAM overlay to ``writer``."""
    data, _ = next(iter(data_loader))
    # Only one image is visualized, so compute the CAM for that image alone.
    image = data[:1].to(device)

    # GradCAM registers hooks on the target layers. The context manager removes
    # them afterwards, so they are not left attached during later training epochs
    # or the ONNX export.
    target_layers = [model.conv2, model.conv1]
    with GradCAM(model=model, target_layers=target_layers) as cam:
        grayscale_cam = cam(input_tensor=image)[0, :]

    # Convert the normalized single-channel image into an HWC RGB float image in
    # [0, 1], which is the input format show_cam_on_image expects.
    gray = image[0].cpu().squeeze().numpy()
    rgb_img = np.repeat(gray[:, :, np.newaxis], 3, axis=2)
    lo, hi = np.min(rgb_img), np.max(rgb_img)
    # Guard against a constant image, which would otherwise divide by zero.
    rgb_img = (rgb_img - lo) / (hi - lo) if hi > lo else np.zeros_like(rgb_img)
    visualization = show_cam_on_image(rgb_img, grayscale_cam)

    writer.add_image(f'Input Image Epoch {epoch}', rgb_img, epoch, dataformats='HWC')
    writer.add_image(f'Overlay Epoch {epoch}', visualization, epoch, dataformats='HWC')


def main():
    """Train a model on MNIST with TensorBoard logging, then save it as .pth and ONNX.

    Steps:
    1. Log the model graph.
    2. For each epoch: train, time the epoch, evaluate on the train and
       validation sets, and log a Grad-CAM overlay of one validation image.
    3. Save the state_dict and export an ONNX model.
    """
    model_name = 'FCN'
    # model_name = 'CNNNet'

    # Select the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f'Using device: {device}')

    # Load the MNIST dataset
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

    validate_dataset = datasets.MNIST('./data', train=False, transform=transform)
    # Validation does not need shuffling. A fixed order also means the Grad-CAM
    # sample (the first validation image) is the same each epoch, so overlays are comparable.
    validate_loader = DataLoader(validate_dataset, batch_size=64, shuffle=False)

    # Initialize the model, loss function and optimizer
    model = _build_model(model_name, device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())

    # Timestamp for a unique TensorBoard run directory
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    log_dir = f"runs/{model_name}_experiment_{timestamp}"
    os.makedirs(log_dir, exist_ok=True)

    # Add the model graph to TensorBoard, traced on one training batch
    with SummaryWriter(log_dir) as writer_model:
        images, _ = next(iter(train_loader))
        writer_model.add_graph(model, images.to(device))

    # Training and validation loop
    num_epochs = 1
    for epoch in range(num_epochs):
        # One writer (and log sub-directory) per epoch, closed when the epoch ends
        with SummaryWriter(f"{log_dir}/epoch_{epoch}") as writer_epoch:
            epoch_start_time = time.time()
            train(model, device, writer_epoch, train_loader, optimizer, criterion, epoch)
            epoch_duration = time.time() - epoch_start_time
            writer_epoch.add_scalar('epoch_duration (seconds)', epoch_duration, epoch)

            train_accuracy, train_loss = evaluate(model, device, train_loader, criterion)
            validate_accuracy, validate_loss = evaluate(model, device, validate_loader, criterion)

            print(f'Epoch {epoch + 1}, train_accuracy = {train_accuracy:.4f}, validate_accuracy = {validate_accuracy:.4f}')
            print(f'Epoch {epoch + 1}, train_loss = {train_loss:.4f}, validate_loss = {validate_loss:.4f}')

            # Generate a Grad-CAM heatmap for one validation image
            _log_grad_cam(model, validate_loader, device, writer_epoch, epoch)

    # Save the trained model parameters
    model_path = 'mnist_cnn_pytorch.pth'
    torch.save(model.state_dict(), model_path)
    print(f"Model saved to {model_path}")

    # Export to ONNX in eval mode with a dummy single-channel 28x28 input
    model.eval()
    input_shape = (1, 1, 28, 28)
    dummy_input = torch.randn(input_shape).to(device)
    onnx_file_path = "mnist_cnn_onnx.onnx"
    torch.onnx.export(model, dummy_input, onnx_file_path, export_params=True, opset_version=17, do_constant_folding=True)

    print(f"Train finished, Export Model to {onnx_file_path}")
    
if __name__ == "__main__":
    # Script entry point. A Jupyter cell also runs with __name__ == "__main__",
    # so executing this cell starts the full train/evaluate/export run.
    main()

Using device: cuda
Epoch 1, Batch [100/938], AvgLoss: 1.0551
Epoch 1, Batch [200/938], AvgLoss: 0.2589
Epoch 1, Batch [300/938], AvgLoss: 0.1918
Epoch 1, Batch [400/938], AvgLoss: 0.1480
Epoch 1, Batch [500/938], AvgLoss: 0.1134
Epoch 1, Batch [600/938], AvgLoss: 0.1090
Epoch 1, Batch [700/938], AvgLoss: 0.0974
Epoch 1, Batch [800/938], AvgLoss: 0.0850
Epoch 1, Batch [900/938], AvgLoss: 0.0784
Epoch 1, train_accuracy = 0.9812, validate_accuracy = 0.9808
Epoch 1, train_loss = 0.0632, validate_loss = 0.0601
Model save to mnist_cnn_pytorch
Train finished, Export Model to mnist_cnn_onnx.onnx
