# In [None]:  (Jupyter cell marker, kept as a comment so the file parses as Python)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from torchvision import models
import time
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

class CustomResNet50(nn.Module):
    """ResNet-50 classifier with a configurable input-channel and class count.

    The defaults (one input channel, ten classes) reproduce the original
    hard-coded configuration used for MNIST in this file.

    Args:
        in_channels: Number of channels of the input images.
        num_classes: Number of outputs of the final classification layer.
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10):
        super().__init__()
        # Backbone is created without pretrained weights.
        self.resnet50 = models.resnet50(weights=None)
        # Replace the stem convolution so the network accepts `in_channels`
        # input channels (1 by default, for grayscale images).
        self.resnet50.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Resize the classification head to `num_classes` outputs.
        num_ftrs = self.resnet50.fc.in_features
        self.resnet50.fc = nn.Linear(num_ftrs, num_classes)

    def forward(self, x):
        """Return the class logits produced by the backbone for batch ``x``."""
        # conv1 and fc were replaced in place on the backbone, so its own
        # forward pass (conv1 -> bn1 -> relu -> maxpool -> layer1..4 ->
        # avgpool -> flatten -> fc) already is the pipeline we need; delegating
        # avoids duplicating torchvision's implementation here.
        return self.resnet50(x)
        
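# CAM obtains its weights by replacing the fully connected layer with a global average pooling (GAP) layer and retraining, whereas Grad-CAM takes a different route and computes the weights as the global average of the gradients. A rigorous derivation shows that the weights computed by Grad-CAM and CAM are in fact equivalent (up to a proportionality constant).
# Define the training function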
def _log_avg_loss(writer_epoch, data_loader, epoch, batches_seen, avg_loss):
    """Write one averaged-loss point to TensorBoard and echo it to stdout."""
    writer_epoch.add_scalars('training loss pre batch', {f'avg_loss{epoch+1}': avg_loss}, batches_seen)
    print(f'Epoch {epoch + 1}, Batch [{batches_seen}/{len(data_loader)}], AvgLoss: {avg_loss:.4f}')


def train(model, device:torch.device, writer_epoch:SummaryWriter, data_loader:DataLoader, optimizer:optim.Optimizer, criterion, epoch, log_interval:int=100)->None:
    """Run one training epoch and log the windowed average loss.

    Args:
        model: Network to optimise; it is switched to training mode.
        device: Device every batch is moved to before the forward pass.
        writer_epoch: TensorBoard writer receiving the averaged loss values.
        data_loader: Iterable of ``(data, target)`` batches; ``len()`` is used
            for the progress message.
        optimizer: Optimizer that updates the model parameters.
        criterion: Loss callable invoked as ``criterion(output, target)``.
        epoch: Zero-based epoch index, used only to label the log entries.
        log_interval: Number of batches averaged per log entry (default 100,
            the previously hard-coded value).

    Raises:
        ValueError: If ``log_interval`` is not positive.
    """
    if log_interval <= 0:
        raise ValueError(f"log_interval must be positive, got {log_interval}")

    model.train()
    running_loss = 0.0
    pending = 0        # batches accumulated since the last log entry
    batches_seen = 0
    for batch_idx, (data, target) in enumerate(data_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        pending += 1
        batches_seen = batch_idx + 1
        if pending == log_interval:
            _log_avg_loss(writer_epoch, data_loader, epoch, batches_seen, running_loss / pending)
            running_loss, pending = 0.0, 0

    # Report the trailing partial window too; otherwise the last
    # `len(data_loader) % log_interval` batches would never be logged.
    if pending:
        _log_avg_loss(writer_epoch, data_loader, epoch, batches_seen, running_loss / pending)

# Define the validation function
def evaluate(model, device:torch.device, data_loader:DataLoader, criterion) ->tuple[float, float, torch.Tensor]:
    """Evaluate ``model`` on every batch of ``data_loader`` without gradients.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        device: Device every batch is moved to before the forward pass.
        data_loader: Iterable of ``(data, target)`` batches.
        criterion: Loss callable invoked as ``criterion(output, target)``.

    Returns:
        A tuple ``(accuracy, loss, first_batch)`` where ``accuracy`` is the
        fraction (0..1) of correctly classified samples actually seen,
        ``loss`` is the mean of the per-batch losses, and ``first_batch`` is
        the input tensor of the first batch as yielded by the loader (before
        it is moved to ``device``).

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    num_samples = 0
    correct_count = 0
    first_batch = None
    with torch.no_grad():
        for data, target in data_loader:
            # Keep the first batch from this same pass instead of opening a
            # second iterator over the loader just to fetch it.
            if first_batch is None:
                first_batch = data
            data, target = data.to(device), target.to(device)
            output = model(data)  # one row of class scores per sample
            total_loss += criterion(output, target).item()
            num_batches += 1
            pred = output.argmax(dim=1, keepdim=True)
            correct_count += pred.eq(target.view_as(pred)).sum().item()
            # Count the samples actually evaluated so the accuracy stays
            # correct even when the loader does not cover the whole dataset.
            num_samples += pred.numel()

    if num_batches == 0:
        raise ValueError("data_loader yielded no batches; nothing to evaluate")

    loss = total_loss / num_batches
    accuracy = correct_count / num_samples
    return accuracy, loss, first_batch

def _log_grad_cam(model, device, writer, images, epoch):
    """Plot a Grad-CAM overlay for the first image of ``images`` and log it to TensorBoard."""
    # Only the first image is visualised, so Grad-CAM is run on that single
    # sample rather than on the whole batch.
    sample = images[:1].to(device)
    target_layers = [model.resnet50.layer4[-1]]
    # Use GradCAM as a context manager so the hooks it registers on the target
    # layer are released afterwards; otherwise they would stay attached during
    # the following training epochs and the ONNX export.
    with GradCAM(model=model, target_layers=target_layers) as cam:
        mask = cam(input_tensor=sample)[0, :]  # (H, W) heatmap of the first image

    # Build a 3-channel float image in [0, 1] from the grayscale sample.
    gray_img = images[0].cpu().squeeze().numpy()
    print(gray_img.shape)
    rgb_img_hwc = np.repeat(gray_img[:, :, np.newaxis], 3, axis=2)  # (H, W, C)
    print(rgb_img_hwc.shape)
    rgb_img_hwc = cv2.normalize(rgb_img_hwc, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX)
    visualization_c3u8_rgb = show_cam_on_image(rgb_img_hwc, mask, use_rgb=True)

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    axs[0].imshow(rgb_img_hwc)
    axs[0].set_title('Image')
    axs[1].imshow(mask)
    axs[1].set_title('Heatmap')
    axs[2].imshow(visualization_c3u8_rgb)
    axs[2].set_title('CAMonImage')
    plt.tight_layout()
    plt.show()
    plt.close(fig)  # do not accumulate one open figure per epoch

    rgb_img_hwc_c3u8 = cv2.normalize(rgb_img_hwc, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_32F).astype(np.uint8)

    # Reshape both images from [H, W, C] to [B, C, H, W] and stack them.
    img_tensor = torch.from_numpy(rgb_img_hwc_c3u8).permute(2, 0, 1).unsqueeze(0)
    visualization_tensor = torch.from_numpy(visualization_c3u8_rgb).permute(2, 0, 1).unsqueeze(0)
    cam_images_tensor = torch.cat((img_tensor, visualization_tensor), dim=0)

    writer.add_images("cam_images", cam_images_tensor, global_step=epoch+1)


def main():
    """Train CustomResNet50 on MNIST, log to TensorBoard, then save and export the model.

    Per epoch this trains, evaluates on the training and validation sets,
    logs accuracy/loss/duration and a Grad-CAM visualisation. Afterwards the
    weights are saved as a ``.pth`` state dict and the model is exported to ONNX.
    """
    # Select the device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f'Using device: {device}')

    # Hyper-parameters.
    model_name = 'CustomResNet50'
    batch_size = 64
    learning_rate = 0.001
    num_epochs = 3

    # Load the MNIST dataset.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    print(train_dataset)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    validate_dataset = datasets.MNIST('./data', train=False, transform=transform)
    print(validate_dataset)
    # Validation metrics do not depend on sample order, so no shuffling is needed.
    validate_loader = DataLoader(validate_dataset, batch_size=batch_size, shuffle=False)

    # Create the model, loss and optimizer.
    model = CustomResNet50().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr = learning_rate)

    # Timestamped log directory for this run.
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    log_dir = f"runs/{model_name}_experiment_{timestamp}"
    os.makedirs(log_dir, exist_ok=True)

    # A single writer is shared by the whole run and always closed, instead of
    # opening a new one per epoch and closing only the last.
    writer = SummaryWriter(log_dir)
    try:
        # Add the model graph to TensorBoard using one training batch.
        images, _ = next(iter(train_loader))
        print(images.shape)
        writer.add_graph(model, images.to(device))

        # Training and validation loop.
        for epoch in range(num_epochs):
            epoch_start_time = time.time()
            train(model, device, writer, train_loader, optimizer, criterion, epoch)
            epoch_duration = time.time() - epoch_start_time
            writer.add_scalar('epoch_duration (seconds)', epoch_duration, epoch+1)

            train_accuracy, train_loss, batch0 = evaluate(model, device, train_loader, criterion)
            validate_accuracy, validate_loss, _ = evaluate(model, device, validate_loader, criterion)
            # evaluate() returns fractions in [0, 1]; format them as real percentages.
            print(f'train_accuracy = {train_accuracy:.2%}, validate_accuracy = {validate_accuracy:.2%}')

            writer.add_scalars('accurancy', {'train accurancy': train_accuracy,
                                             'validate accurancy':validate_accuracy},global_step=epoch+1)
            writer.add_scalars('loss', {'train loss': train_loss,
                                        'validate loss':validate_loss},global_step=epoch+1)

            img_grid = make_grid(batch0, nrow=8, normalize=False, scale_each=False)
            writer.add_image(f'batch0_imgs{epoch+1}', img_grid, global_step=None)

            # Grad-CAM heatmap for the first image of the batch fetched above.
            _log_grad_cam(model, device, writer, images, epoch)
    finally:
        writer.close()

    # Save the trained parameters.
    model_path = f'mnist_{model_name}.pth'
    torch.save(model.state_dict(), model_path)
    print(f"Model saved to {model_path}")

    # Export to ONNX using a dummy single-channel 28x28 input (MNIST image size).
    dummy_input_shape = (1, 1, 28, 28)
    dummy_input = torch.randn(dummy_input_shape).to(device)
    onnx_file_path = f"mnist_{model_name}_onnx.onnx"
    torch.onnx.export(model, dummy_input, onnx_file_path, export_params=True, opset_version=17, do_constant_folding=True)

    print(f"Train finished, Export Model to {onnx_file_path}")
    
# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()