In [19]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50
from typing import Optional
from torch.optim.lr_scheduler import StepLR
# 超参数
batch_size = 32
learning_rate = 0.001
num_epochs = 100
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pic_size=224
# 数据预处理和加载
transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转
    # transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False),  # 随机擦除
    # transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),  # 颜色抖动
    transforms.Resize([pic_size,pic_size],antialias=False),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 加载训练和验证数据集
train_dataset = datasets.ImageFolder(root='MyGTData/train', transform=transform)
test_dataset = datasets.ImageFolder(root='MyGTData/test', transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True,pin_memory=True,num_workers=0)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True,pin_memory=True,num_workers=0)
# torch.autograd.set_detect_anomaly(True)

from DF import DilateAttention,MultiDilatelocalAttention
from CAA import CAA,ConvModule
from LossFunction import FocalLoss,LabelSmoothingLoss

In [20]:
from torchvision.models import wide_resnet50_2
class ResNetWithDilateAttention(nn.Module):
    def __init__(self,channels,pic_size,nc):
        super(ResNetWithDilateAttention, self).__init__()
        # 使用ResNet50的预训练模型
        self.mdla=MultiDilatelocalAttention(dim=pic_size).cuda()
        self.caa=CAA(channels=channels).cuda()
        self.resnet = resnet50(weights=None,num_classes=nc)

    def forward(self, x):
        x=self.mdla(x)     
        x=self.caa(x)
        x = self.resnet(x)
        return x

# 实例化模型、定义损失函数和优化器
model = ResNetWithDilateAttention(channels=3,pic_size=pic_size,nc=4)
model.to(device)

cross_entropy_loss = nn.CrossEntropyLoss()
focal_loss = FocalLoss(alpha=1, gamma=2)
label_smoothing_loss = LabelSmoothingLoss(classes=4, smoothing=0.1)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

scheduler = StepLR(optimizer, step_size=10, gamma=0.8)


In [21]:
import torch
import logging
import time
from sklearn.metrics import f1_score
import numpy as np
from sklearn.metrics import roc_auc_score

# 配置日志记录
logging.basicConfig(filename='training.log', level=logging.INFO, format='%(asctime)s - %(message)s')

# 记录训练开始时间
start_time = time.time()

for epoch in range(1, num_epochs + 1):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        data, target = data.to(device), target.to(device)
        output = model(data)
        
        # 计算损失
        ce_loss = cross_entropy_loss(output, target)
        focal = focal_loss(output, target)
        ls_loss = label_smoothing_loss(output, target)
        loss = ce_loss + focal + ls_loss
        loss.backward()
        optimizer.step()

        # 每隔100个batch打印并记录日志
        if batch_idx % 100 == 0:
            log_message = f'Epoch {epoch}/{num_epochs}, Batch {batch_idx}, Total Loss: {loss.item():.4f}, Cross Entropy Loss: {ce_loss.item():.4f}, Focal Loss: {focal.item():.4f}, Label Smoothing Loss: {ls_loss.item():.4f}'
            print(log_message)
            logging.info(log_message)

    # 每隔10个epoch进行测试并记录
    if epoch % 10 == 0:
        correct = 0
        total = 0
        all_preds = []
        all_targets = []

        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                outputs = model(data)
                probabilities = torch.softmax(outputs, dim=1)
                _, predicted = torch.max(outputs.data, 1)

                total += target.size(0)
                correct += (predicted == target).sum().item()

                # 保存预测结果
                all_preds.extend(probabilities.cpu().numpy())
                all_targets.extend(target.cpu().numpy())

        # 计算准确率
        accuracy = 100 * correct / total
        log_message = f'Test Accuracy: {accuracy:.2f}%'
        print(log_message)
        logging.info(log_message)

        # 计算F1分数
        f1 = f1_score(all_targets, np.argmax(all_preds, axis=1), average='macro')
        log_message = f'F1 Score: {f1:.2f}'
        print(log_message)
        logging.info(log_message)
        
        
        auc = roc_auc_score(all_targets, np.max(all_preds, axis=1))
        log_message = f'ROC-AUC: {auc:.2f}'
        print(log_message)
        logging.info(log_message)


# 记录总训练时间
end_time = time.time()
total_time = end_time - start_time
log_message = f'Total training time: {total_time:.2f} seconds'
print(log_message)
logging.info(log_message)


Epoch 1/100, Batch 0, Total Loss: 5.6668, Cross Entropy Loss: 2.0706, Focal Loss: 1.5845, Label Smoothing Loss: 2.0118
Epoch 2/100, Batch 0, Total Loss: 3.0224, Cross Entropy Loss: 0.8488, Focal Loss: 0.8224, Label Smoothing Loss: 1.3512
Epoch 3/100, Batch 0, Total Loss: 1.2925, Cross Entropy Loss: 0.3772, Focal Loss: 0.1268, Label Smoothing Loss: 0.7885
Epoch 4/100, Batch 0, Total Loss: 1.2565, Cross Entropy Loss: 0.3291, Focal Loss: 0.1720, Label Smoothing Loss: 0.7554
Epoch 5/100, Batch 0, Total Loss: 0.8780, Cross Entropy Loss: 0.1923, Focal Loss: 0.0665, Label Smoothing Loss: 0.6193
Epoch 6/100, Batch 0, Total Loss: 0.8670, Cross Entropy Loss: 0.2476, Focal Loss: 0.0356, Label Smoothing Loss: 0.5838
Epoch 7/100, Batch 0, Total Loss: 0.8780, Cross Entropy Loss: 0.1976, Focal Loss: 0.0859, Label Smoothing Loss: 0.5945
Epoch 8/100, Batch 0, Total Loss: 0.7634, Cross Entropy Loss: 0.2080, Focal Loss: 0.0492, Label Smoothing Loss: 0.5062
Epoch 9/100, Batch 0, Total Loss: 0.8525, Cross 

In [27]:
print(transform)
transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转
    # transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False),  # 随机擦除
    # transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),  # 颜色抖动
    transforms.Resize([pic_size,pic_size],antialias=False),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

ToPILImage()


In [None]:
torch.save(model,"./DF_CAA_ResNet.pth")
model.eval()
img_path=r"D:\Datas\01DWdatas\6kV西霞线\照片\6kV西霞线\IMG_4090.JPG"
print(model)
from PIL import Image, ImageFile
img = Image.open(img_path).convert('RGB')
import torchvision.transforms as T

input=transform(img).cuda()

input=torch.unsqueeze(input,0)
print(input.shape)
layer_name="ta"
# 假设你的模型叫model，输入是input
layer = model.mdla  # 替换为你想要的层名
with torch.no_grad():
    output = layer(input)

print(type(output))
output = output.squeeze(0)
print(output.shape)
import torchvision.transforms as transforms
transform = transforms.ToPILImage()
image = transform(output)

# 显示图像
image.show()
image.save("after_ta.jpg")
output=model(input)
print(output.shape)


ResNetWithDilateAttention(
  (mdla): MultiDilatelocalAttention(
    (qkv): Conv2d(224, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (dilate_attention): ModuleList(
      (0): DilateAttention(
        (unfold): Unfold(kernel_size=3, dilation=2, padding=2, stride=1)
        (attn_drop): Dropout(p=0.0, inplace=False)
      )
      (1): DilateAttention(
        (unfold): Unfold(kernel_size=3, dilation=3, padding=3, stride=1)
        (attn_drop): Dropout(p=0.0, inplace=False)
      )
    )
    (proj): Linear(in_features=224, out_features=224, bias=True)
    (proj_drop): Dropout(p=0.0, inplace=False)
  )
  (caa): CAA(
    (avg_pool): AvgPool2d(kernel_size=7, stride=1, padding=3)
    (conv1): ConvModule(
      (block): Sequential(
        (0): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(3, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (2): SiLU(inplace=True)
      )
    )
    (h_conv): ConvModule(
      (block): Se

In [None]:

#测试模型
import torch
from sklearn.metrics import f1_score
import numpy as np
for epoch in range(1,num_epochs+1):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        data, target = data.to(device), target.to(device)
        output = model(data)
        # Compute individual losses
        ce_loss = cross_entropy_loss(output, target)
        focal = focal_loss(output, target)
        ls_loss = label_smoothing_loss(output, target)
        loss = ce_loss +focal +ls_loss
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(f'Epoch {epoch}/{num_epochs}, Batch {batch_idx},Total Loss: {loss.item():.4f},Cross Entropy Loss: {ce_loss.item():.4f},Focal Loss: {focal.item():.4f},Label Smoothing Loss: {ls_loss.item():.4f}')

    if epoch %10 ==0:
        correct = 0
        total = 0
        all_preds = []
        all_targets = []

        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                outputs = model(data)
                probabilities = torch.softmax(outputs, dim=1)
                _, predicted = torch.max(outputs.data, 1)

                total += target.size(0)
                correct += (predicted == target).sum().item()

                # 保存所有预测概率和真实标签
                all_preds.extend(probabilities.cpu().numpy())
                all_targets.extend(target.cpu().numpy())

        # 计算准确率
        accuracy = 100 * correct / total
        print(f'Test Accuracy: {accuracy:.2f}%')

        # 计算F1分数
        f1 = f1_score(all_targets, np.argmax(all_preds, axis=1), average='macro')
        print(f'F1 Score: {f1:.2f}')


In [37]:
import torch
from sklearn.metrics import f1_score, roc_curve, auc, roc_auc_score
import numpy as np
correct = 0
total = 0
all_preds = []
all_targets = []

with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        outputs = model(data)
        probabilities = torch.softmax(outputs, dim=1)
        _, predicted = torch.max(outputs.data, 1)
        
        total += target.size(0)
        correct += (predicted == target).sum().item()
        
        # 保存所有预测概率和真实标签
        all_preds.extend(probabilities.cpu().numpy())
        all_targets.extend(target.cpu().numpy())

# 计算准确率
accuracy = 100 * correct / total
print(f'Test Accuracy: {accuracy:.2f}%')

# 计算F1分数
f1 = f1_score(all_targets, np.argmax(all_preds, axis=1), average='macro')
print(f'F1 Score: {f1:.2f}')


Test Accuracy: 85.71%
F1 Score: 0.83


In [4]:
print(model)

ResNetWithDilateAttention(
  (mdla): MultiDilatelocalAttention(
    (qkv): Conv2d(224, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (dilate_attention): ModuleList(
      (0): DilateAttention(
        (unfold): Unfold(kernel_size=3, dilation=2, padding=2, stride=1)
        (attn_drop): Dropout(p=0.0, inplace=False)
      )
      (1): DilateAttention(
        (unfold): Unfold(kernel_size=3, dilation=3, padding=3, stride=1)
        (attn_drop): Dropout(p=0.0, inplace=False)
      )
    )
    (proj): Linear(in_features=224, out_features=224, bias=True)
    (proj_drop): Dropout(p=0.0, inplace=False)
  )
  (caa): CAA(
    (avg_pool): AvgPool2d(kernel_size=7, stride=1, padding=3)
    (conv1): ConvModule(
      (block): Sequential(
        (0): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(3, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (2): SiLU(inplace=True)
      )
    )
    (h_conv): ConvModule(
      (block): Se

In [9]:
DF_CAA_ResNet=r"C:\Users\gaoge\Desktop\GHOME\01projs\03mycodes\tmp\12杆塔分类\model.pth"
DF_CAA_ResNet=torch.load(DF_CAA_ResNet)
model.load_state_dict(DF_CAA_ResNet)
with torch.no_grad():
    model.eval()

RuntimeError: Error(s) in loading state_dict for ResNetWithDilateAttention:
	Missing key(s) in state_dict: "mdla.qkv.weight", "mdla.proj.weight", "mdla.proj.bias", "caa.conv1.block.0.weight", "caa.conv1.block.1.weight", "caa.conv1.block.1.bias", "caa.conv1.block.1.running_mean", "caa.conv1.block.1.running_var", "caa.h_conv.block.0.weight", "caa.h_conv.block.0.bias", "caa.v_conv.block.0.weight", "caa.v_conv.block.0.bias", "caa.conv2.block.0.weight", "caa.conv2.block.1.weight", "caa.conv2.block.1.bias", "caa.conv2.block.1.running_mean", "caa.conv2.block.1.running_var", "resnet.conv1.weight", "resnet.bn1.weight", "resnet.bn1.bias", "resnet.bn1.running_mean", "resnet.bn1.running_var", "resnet.layer1.0.conv1.weight", "resnet.layer1.0.bn1.weight", "resnet.layer1.0.bn1.bias", "resnet.layer1.0.bn1.running_mean", "resnet.layer1.0.bn1.running_var", "resnet.layer1.0.conv2.weight", "resnet.layer1.0.bn2.weight", "resnet.layer1.0.bn2.bias", "resnet.layer1.0.bn2.running_mean", "resnet.layer1.0.bn2.running_var", "resnet.layer1.0.conv3.weight", "resnet.layer1.0.bn3.weight", "resnet.layer1.0.bn3.bias", "resnet.layer1.0.bn3.running_mean", "resnet.layer1.0.bn3.running_var", "resnet.layer1.0.downsample.0.weight", "resnet.layer1.0.downsample.1.weight", "resnet.layer1.0.downsample.1.bias", "resnet.layer1.0.downsample.1.running_mean", "resnet.layer1.0.downsample.1.running_var", "resnet.layer1.1.conv1.weight", "resnet.layer1.1.bn1.weight", "resnet.layer1.1.bn1.bias", "resnet.layer1.1.bn1.running_mean", "resnet.layer1.1.bn1.running_var", "resnet.layer1.1.conv2.weight", "resnet.layer1.1.bn2.weight", "resnet.layer1.1.bn2.bias", "resnet.layer1.1.bn2.running_mean", "resnet.layer1.1.bn2.running_var", "resnet.layer1.1.conv3.weight", "resnet.layer1.1.bn3.weight", "resnet.layer1.1.bn3.bias", "resnet.layer1.1.bn3.running_mean", "resnet.layer1.1.bn3.running_var", "resnet.layer1.2.conv1.weight", "resnet.layer1.2.bn1.weight", "resnet.layer1.2.bn1.bias", "resnet.layer1.2.bn1.running_mean", "resnet.layer1.2.bn1.running_var", "resnet.layer1.2.conv2.weight", "resnet.layer1.2.bn2.weight", "resnet.layer1.2.bn2.bias", "resnet.layer1.2.bn2.running_mean", "resnet.layer1.2.bn2.running_var", "resnet.layer1.2.conv3.weight", "resnet.layer1.2.bn3.weight", "resnet.layer1.2.bn3.bias", "resnet.layer1.2.bn3.running_mean", "resnet.layer1.2.bn3.running_var", "resnet.layer2.0.conv1.weight", "resnet.layer2.0.bn1.weight", "resnet.layer2.0.bn1.bias", "resnet.layer2.0.bn1.running_mean", "resnet.layer2.0.bn1.running_var", "resnet.layer2.0.conv2.weight", "resnet.layer2.0.bn2.weight", "resnet.layer2.0.bn2.bias", "resnet.layer2.0.bn2.running_mean", "resnet.layer2.0.bn2.running_var", "resnet.layer2.0.conv3.weight", "resnet.layer2.0.bn3.weight", "resnet.layer2.0.bn3.bias", "resnet.layer2.0.bn3.running_mean", "resnet.layer2.0.bn3.running_var", "resnet.layer2.0.downsample.0.weight", "resnet.layer2.0.downsample.1.weight", "resnet.layer2.0.downsample.1.bias", "resnet.layer2.0.downsample.1.running_mean", "resnet.layer2.0.downsample.1.running_var", "resnet.layer2.1.conv1.weight", "resnet.layer2.1.bn1.weight", "resnet.layer2.1.bn1.bias", "resnet.layer2.1.bn1.running_mean", "resnet.layer2.1.bn1.running_var", "resnet.layer2.1.conv2.weight", "resnet.layer2.1.bn2.weight", "resnet.layer2.1.bn2.bias", "resnet.layer2.1.bn2.running_mean", "resnet.layer2.1.bn2.running_var", "resnet.layer2.1.conv3.weight", "resnet.layer2.1.bn3.weight", "resnet.layer2.1.bn3.bias", "resnet.layer2.1.bn3.running_mean", "resnet.layer2.1.bn3.running_var", "resnet.layer2.2.conv1.weight", "resnet.layer2.2.bn1.weight", "resnet.layer2.2.bn1.bias", "resnet.layer2.2.bn1.running_mean", "resnet.layer2.2.bn1.running_var", "resnet.layer2.2.conv2.weight", "resnet.layer2.2.bn2.weight", "resnet.layer2.2.bn2.bias", "resnet.layer2.2.bn2.running_mean", "resnet.layer2.2.bn2.running_var", "resnet.layer2.2.conv3.weight", "resnet.layer2.2.bn3.weight", "resnet.layer2.2.bn3.bias", "resnet.layer2.2.bn3.running_mean", "resnet.layer2.2.bn3.running_var", "resnet.layer2.3.conv1.weight", "resnet.layer2.3.bn1.weight", "resnet.layer2.3.bn1.bias", "resnet.layer2.3.bn1.running_mean", "resnet.layer2.3.bn1.running_var", "resnet.layer2.3.conv2.weight", "resnet.layer2.3.bn2.weight", "resnet.layer2.3.bn2.bias", "resnet.layer2.3.bn2.running_mean", "resnet.layer2.3.bn2.running_var", "resnet.layer2.3.conv3.weight", "resnet.layer2.3.bn3.weight", "resnet.layer2.3.bn3.bias", "resnet.layer2.3.bn3.running_mean", "resnet.layer2.3.bn3.running_var", "resnet.layer3.0.conv1.weight", "resnet.layer3.0.bn1.weight", "resnet.layer3.0.bn1.bias", "resnet.layer3.0.bn1.running_mean", "resnet.layer3.0.bn1.running_var", "resnet.layer3.0.conv2.weight", "resnet.layer3.0.bn2.weight", "resnet.layer3.0.bn2.bias", "resnet.layer3.0.bn2.running_mean", "resnet.layer3.0.bn2.running_var", "resnet.layer3.0.conv3.weight", "resnet.layer3.0.bn3.weight", "resnet.layer3.0.bn3.bias", "resnet.layer3.0.bn3.running_mean", "resnet.layer3.0.bn3.running_var", "resnet.layer3.0.downsample.0.weight", "resnet.layer3.0.downsample.1.weight", "resnet.layer3.0.downsample.1.bias", "resnet.layer3.0.downsample.1.running_mean", "resnet.layer3.0.downsample.1.running_var", "resnet.layer3.1.conv1.weight", "resnet.layer3.1.bn1.weight", "resnet.layer3.1.bn1.bias", "resnet.layer3.1.bn1.running_mean", "resnet.layer3.1.bn1.running_var", "resnet.layer3.1.conv2.weight", "resnet.layer3.1.bn2.weight", "resnet.layer3.1.bn2.bias", "resnet.layer3.1.bn2.running_mean", "resnet.layer3.1.bn2.running_var", "resnet.layer3.1.conv3.weight", "resnet.layer3.1.bn3.weight", "resnet.layer3.1.bn3.bias", "resnet.layer3.1.bn3.running_mean", "resnet.layer3.1.bn3.running_var", "resnet.layer3.2.conv1.weight", "resnet.layer3.2.bn1.weight", "resnet.layer3.2.bn1.bias", "resnet.layer3.2.bn1.running_mean", "resnet.layer3.2.bn1.running_var", "resnet.layer3.2.conv2.weight", "resnet.layer3.2.bn2.weight", "resnet.layer3.2.bn2.bias", "resnet.layer3.2.bn2.running_mean", "resnet.layer3.2.bn2.running_var", "resnet.layer3.2.conv3.weight", "resnet.layer3.2.bn3.weight", "resnet.layer3.2.bn3.bias", "resnet.layer3.2.bn3.running_mean", "resnet.layer3.2.bn3.running_var", "resnet.layer3.3.conv1.weight", "resnet.layer3.3.bn1.weight", "resnet.layer3.3.bn1.bias", "resnet.layer3.3.bn1.running_mean", "resnet.layer3.3.bn1.running_var", "resnet.layer3.3.conv2.weight", "resnet.layer3.3.bn2.weight", "resnet.layer3.3.bn2.bias", "resnet.layer3.3.bn2.running_mean", "resnet.layer3.3.bn2.running_var", "resnet.layer3.3.conv3.weight", "resnet.layer3.3.bn3.weight", "resnet.layer3.3.bn3.bias", "resnet.layer3.3.bn3.running_mean", "resnet.layer3.3.bn3.running_var", "resnet.layer3.4.conv1.weight", "resnet.layer3.4.bn1.weight", "resnet.layer3.4.bn1.bias", "resnet.layer3.4.bn1.running_mean", "resnet.layer3.4.bn1.running_var", "resnet.layer3.4.conv2.weight", "resnet.layer3.4.bn2.weight", "resnet.layer3.4.bn2.bias", "resnet.layer3.4.bn2.running_mean", "resnet.layer3.4.bn2.running_var", "resnet.layer3.4.conv3.weight", "resnet.layer3.4.bn3.weight", "resnet.layer3.4.bn3.bias", "resnet.layer3.4.bn3.running_mean", "resnet.layer3.4.bn3.running_var", "resnet.layer3.5.conv1.weight", "resnet.layer3.5.bn1.weight", "resnet.layer3.5.bn1.bias", "resnet.layer3.5.bn1.running_mean", "resnet.layer3.5.bn1.running_var", "resnet.layer3.5.conv2.weight", "resnet.layer3.5.bn2.weight", "resnet.layer3.5.bn2.bias", "resnet.layer3.5.bn2.running_mean", "resnet.layer3.5.bn2.running_var", "resnet.layer3.5.conv3.weight", "resnet.layer3.5.bn3.weight", "resnet.layer3.5.bn3.bias", "resnet.layer3.5.bn3.running_mean", "resnet.layer3.5.bn3.running_var", "resnet.layer4.0.conv1.weight", "resnet.layer4.0.bn1.weight", "resnet.layer4.0.bn1.bias", "resnet.layer4.0.bn1.running_mean", "resnet.layer4.0.bn1.running_var", "resnet.layer4.0.conv2.weight", "resnet.layer4.0.bn2.weight", "resnet.layer4.0.bn2.bias", "resnet.layer4.0.bn2.running_mean", "resnet.layer4.0.bn2.running_var", "resnet.layer4.0.conv3.weight", "resnet.layer4.0.bn3.weight", "resnet.layer4.0.bn3.bias", "resnet.layer4.0.bn3.running_mean", "resnet.layer4.0.bn3.running_var", "resnet.layer4.0.downsample.0.weight", "resnet.layer4.0.downsample.1.weight", "resnet.layer4.0.downsample.1.bias", "resnet.layer4.0.downsample.1.running_mean", "resnet.layer4.0.downsample.1.running_var", "resnet.layer4.1.conv1.weight", "resnet.layer4.1.bn1.weight", "resnet.layer4.1.bn1.bias", "resnet.layer4.1.bn1.running_mean", "resnet.layer4.1.bn1.running_var", "resnet.layer4.1.conv2.weight", "resnet.layer4.1.bn2.weight", "resnet.layer4.1.bn2.bias", "resnet.layer4.1.bn2.running_mean", "resnet.layer4.1.bn2.running_var", "resnet.layer4.1.conv3.weight", "resnet.layer4.1.bn3.weight", "resnet.layer4.1.bn3.bias", "resnet.layer4.1.bn3.running_mean", "resnet.layer4.1.bn3.running_var", "resnet.layer4.2.conv1.weight", "resnet.layer4.2.bn1.weight", "resnet.layer4.2.bn1.bias", "resnet.layer4.2.bn1.running_mean", "resnet.layer4.2.bn1.running_var", "resnet.layer4.2.conv2.weight", "resnet.layer4.2.bn2.weight", "resnet.layer4.2.bn2.bias", "resnet.layer4.2.bn2.running_mean", "resnet.layer4.2.bn2.running_var", "resnet.layer4.2.conv3.weight", "resnet.layer4.2.bn3.weight", "resnet.layer4.2.bn3.bias", "resnet.layer4.2.bn3.running_mean", "resnet.layer4.2.bn3.running_var", "resnet.fc.weight", "resnet.fc.bias". 
	Unexpected key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "bn1.num_batches_tracked", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.bn1.num_batches_tracked", "layer1.0.conv2.weight", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.0.bn2.num_batches_tracked", "layer1.0.conv3.weight", "layer1.0.bn3.weight", "layer1.0.bn3.bias", "layer1.0.bn3.running_mean", "layer1.0.bn3.running_var", "layer1.0.bn3.num_batches_tracked", "layer1.0.downsample.0.weight", "layer1.0.downsample.1.weight", "layer1.0.downsample.1.bias", "layer1.0.downsample.1.running_mean", "layer1.0.downsample.1.running_var", "layer1.0.downsample.1.num_batches_tracked", "layer1.1.conv1.weight", "layer1.1.bn1.weight", "layer1.1.bn1.bias", "layer1.1.bn1.running_mean", "layer1.1.bn1.running_var", "layer1.1.bn1.num_batches_tracked", "layer1.1.conv2.weight", "layer1.1.bn2.weight", "layer1.1.bn2.bias", "layer1.1.bn2.running_mean", "layer1.1.bn2.running_var", "layer1.1.bn2.num_batches_tracked", "layer1.1.conv3.weight", "layer1.1.bn3.weight", "layer1.1.bn3.bias", "layer1.1.bn3.running_mean", "layer1.1.bn3.running_var", "layer1.1.bn3.num_batches_tracked", "layer1.2.conv1.weight", "layer1.2.bn1.weight", "layer1.2.bn1.bias", "layer1.2.bn1.running_mean", "layer1.2.bn1.running_var", "layer1.2.bn1.num_batches_tracked", "layer1.2.conv2.weight", "layer1.2.bn2.weight", "layer1.2.bn2.bias", "layer1.2.bn2.running_mean", "layer1.2.bn2.running_var", "layer1.2.bn2.num_batches_tracked", "layer1.2.conv3.weight", "layer1.2.bn3.weight", "layer1.2.bn3.bias", "layer1.2.bn3.running_mean", "layer1.2.bn3.running_var", "layer1.2.bn3.num_batches_tracked", "layer2.0.conv1.weight", "layer2.0.bn1.weight", "layer2.0.bn1.bias", "layer2.0.bn1.running_mean", "layer2.0.bn1.running_var", "layer2.0.bn1.num_batches_tracked", "layer2.0.conv2.weight", "layer2.0.bn2.weight", "layer2.0.bn2.bias", "layer2.0.bn2.running_mean", "layer2.0.bn2.running_var", "layer2.0.bn2.num_batches_tracked", "layer2.0.conv3.weight", "layer2.0.bn3.weight", "layer2.0.bn3.bias", "layer2.0.bn3.running_mean", "layer2.0.bn3.running_var", "layer2.0.bn3.num_batches_tracked", "layer2.0.downsample.0.weight", "layer2.0.downsample.1.weight", "layer2.0.downsample.1.bias", "layer2.0.downsample.1.running_mean", "layer2.0.downsample.1.running_var", "layer2.0.downsample.1.num_batches_tracked", "layer2.1.conv1.weight", "layer2.1.bn1.weight", "layer2.1.bn1.bias", "layer2.1.bn1.running_mean", "layer2.1.bn1.running_var", "layer2.1.bn1.num_batches_tracked", "layer2.1.conv2.weight", "layer2.1.bn2.weight", "layer2.1.bn2.bias", "layer2.1.bn2.running_mean", "layer2.1.bn2.running_var", "layer2.1.bn2.num_batches_tracked", "layer2.1.conv3.weight", "layer2.1.bn3.weight", "layer2.1.bn3.bias", "layer2.1.bn3.running_mean", "layer2.1.bn3.running_var", "layer2.1.bn3.num_batches_tracked", "layer2.2.conv1.weight", "layer2.2.bn1.weight", "layer2.2.bn1.bias", "layer2.2.bn1.running_mean", "layer2.2.bn1.running_var", "layer2.2.bn1.num_batches_tracked", "layer2.2.conv2.weight", "layer2.2.bn2.weight", "layer2.2.bn2.bias", "layer2.2.bn2.running_mean", "layer2.2.bn2.running_var", "layer2.2.bn2.num_batches_tracked", "layer2.2.conv3.weight", "layer2.2.bn3.weight", "layer2.2.bn3.bias", "layer2.2.bn3.running_mean", "layer2.2.bn3.running_var", "layer2.2.bn3.num_batches_tracked", "layer2.3.conv1.weight", "layer2.3.bn1.weight", "layer2.3.bn1.bias", "layer2.3.bn1.running_mean", "layer2.3.bn1.running_var", "layer2.3.bn1.num_batches_tracked", "layer2.3.conv2.weight", "layer2.3.bn2.weight", "layer2.3.bn2.bias", "layer2.3.bn2.running_mean", "layer2.3.bn2.running_var", "layer2.3.bn2.num_batches_tracked", "layer2.3.conv3.weight", "layer2.3.bn3.weight", "layer2.3.bn3.bias", "layer2.3.bn3.running_mean", "layer2.3.bn3.running_var", "layer2.3.bn3.num_batches_tracked", "layer3.0.conv1.weight", "layer3.0.bn1.weight", "layer3.0.bn1.bias", "layer3.0.bn1.running_mean", "layer3.0.bn1.running_var", "layer3.0.bn1.num_batches_tracked", "layer3.0.conv2.weight", "layer3.0.bn2.weight", "layer3.0.bn2.bias", "layer3.0.bn2.running_mean", "layer3.0.bn2.running_var", "layer3.0.bn2.num_batches_tracked", "layer3.0.conv3.weight", "layer3.0.bn3.weight", "layer3.0.bn3.bias", "layer3.0.bn3.running_mean", "layer3.0.bn3.running_var", "layer3.0.bn3.num_batches_tracked", "layer3.0.downsample.0.weight", "layer3.0.downsample.1.weight", "layer3.0.downsample.1.bias", "layer3.0.downsample.1.running_mean", "layer3.0.downsample.1.running_var", "layer3.0.downsample.1.num_batches_tracked", "layer3.1.conv1.weight", "layer3.1.bn1.weight", "layer3.1.bn1.bias", "layer3.1.bn1.running_mean", "layer3.1.bn1.running_var", "layer3.1.bn1.num_batches_tracked", "layer3.1.conv2.weight", "layer3.1.bn2.weight", "layer3.1.bn2.bias", "layer3.1.bn2.running_mean", "layer3.1.bn2.running_var", "layer3.1.bn2.num_batches_tracked", "layer3.1.conv3.weight", "layer3.1.bn3.weight", "layer3.1.bn3.bias", "layer3.1.bn3.running_mean", "layer3.1.bn3.running_var", "layer3.1.bn3.num_batches_tracked", "layer3.2.conv1.weight", "layer3.2.bn1.weight", "layer3.2.bn1.bias", "layer3.2.bn1.running_mean", "layer3.2.bn1.running_var", "layer3.2.bn1.num_batches_tracked", "layer3.2.conv2.weight", "layer3.2.bn2.weight", "layer3.2.bn2.bias", "layer3.2.bn2.running_mean", "layer3.2.bn2.running_var", "layer3.2.bn2.num_batches_tracked", "layer3.2.conv3.weight", "layer3.2.bn3.weight", "layer3.2.bn3.bias", "layer3.2.bn3.running_mean", "layer3.2.bn3.running_var", "layer3.2.bn3.num_batches_tracked", "layer3.3.conv1.weight", "layer3.3.bn1.weight", "layer3.3.bn1.bias", "layer3.3.bn1.running_mean", "layer3.3.bn1.running_var", "layer3.3.bn1.num_batches_tracked", "layer3.3.conv2.weight", "layer3.3.bn2.weight", "layer3.3.bn2.bias", "layer3.3.bn2.running_mean", "layer3.3.bn2.running_var", "layer3.3.bn2.num_batches_tracked", "layer3.3.conv3.weight", "layer3.3.bn3.weight", "layer3.3.bn3.bias", "layer3.3.bn3.running_mean", "layer3.3.bn3.running_var", "layer3.3.bn3.num_batches_tracked", "layer3.4.conv1.weight", "layer3.4.bn1.weight", "layer3.4.bn1.bias", "layer3.4.bn1.running_mean", "layer3.4.bn1.running_var", "layer3.4.bn1.num_batches_tracked", "layer3.4.conv2.weight", "layer3.4.bn2.weight", "layer3.4.bn2.bias", "layer3.4.bn2.running_mean", "layer3.4.bn2.running_var", "layer3.4.bn2.num_batches_tracked", "layer3.4.conv3.weight", "layer3.4.bn3.weight", "layer3.4.bn3.bias", "layer3.4.bn3.running_mean", "layer3.4.bn3.running_var", "layer3.4.bn3.num_batches_tracked", "layer3.5.conv1.weight", "layer3.5.bn1.weight", "layer3.5.bn1.bias", "layer3.5.bn1.running_mean", "layer3.5.bn1.running_var", "layer3.5.bn1.num_batches_tracked", "layer3.5.conv2.weight", "layer3.5.bn2.weight", "layer3.5.bn2.bias", "layer3.5.bn2.running_mean", "layer3.5.bn2.running_var", "layer3.5.bn2.num_batches_tracked", "layer3.5.conv3.weight", "layer3.5.bn3.weight", "layer3.5.bn3.bias", "layer3.5.bn3.running_mean", "layer3.5.bn3.running_var", "layer3.5.bn3.num_batches_tracked", "layer4.0.conv1.weight", "layer4.0.bn1.weight", "layer4.0.bn1.bias", "layer4.0.bn1.running_mean", "layer4.0.bn1.running_var", "layer4.0.bn1.num_batches_tracked", "layer4.0.conv2.weight", "layer4.0.bn2.weight", "layer4.0.bn2.bias", "layer4.0.bn2.running_mean", "layer4.0.bn2.running_var", "layer4.0.bn2.num_batches_tracked", "layer4.0.conv3.weight", "layer4.0.bn3.weight", "layer4.0.bn3.bias", "layer4.0.bn3.running_mean", "layer4.0.bn3.running_var", "layer4.0.bn3.num_batches_tracked", "layer4.0.downsample.0.weight", "layer4.0.downsample.1.weight", "layer4.0.downsample.1.bias", "layer4.0.downsample.1.running_mean", "layer4.0.downsample.1.running_var", "layer4.0.downsample.1.num_batches_tracked", "layer4.1.conv1.weight", "layer4.1.bn1.weight", "layer4.1.bn1.bias", "layer4.1.bn1.running_mean", "layer4.1.bn1.running_var", "layer4.1.bn1.num_batches_tracked", "layer4.1.conv2.weight", "layer4.1.bn2.weight", "layer4.1.bn2.bias", "layer4.1.bn2.running_mean", "layer4.1.bn2.running_var", "layer4.1.bn2.num_batches_tracked", "layer4.1.conv3.weight", "layer4.1.bn3.weight", "layer4.1.bn3.bias", "layer4.1.bn3.running_mean", "layer4.1.bn3.running_var", "layer4.1.bn3.num_batches_tracked", "layer4.2.conv1.weight", "layer4.2.bn1.weight", "layer4.2.bn1.bias", "layer4.2.bn1.running_mean", "layer4.2.bn1.running_var", "layer4.2.bn1.num_batches_tracked", "layer4.2.conv2.weight", "layer4.2.bn2.weight", "layer4.2.bn2.bias", "layer4.2.bn2.running_mean", "layer4.2.bn2.running_var", "layer4.2.bn2.num_batches_tracked", "layer4.2.conv3.weight", "layer4.2.bn3.weight", "layer4.2.bn3.bias", "layer4.2.bn3.running_mean", "layer4.2.bn3.running_var", "layer4.2.bn3.num_batches_tracked", "fc.weight", "fc.bias". 

In [5]:

import torch
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
from torchcam.methods import GradCAM
from torchcam.utils import overlay_mask

# Load and preprocess the image
image_path = r"C:\Users\gaoge\Desktop\GHOME\01projs\03mycodes\tmp\12杆塔分类\MyGTData\train\直线杆\IMG_1498.JPG"
image = Image.open(image_path)

# Define the image transformations (resize, normalize)
transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转
    # transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False),  # 随机擦除
    # transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),  # 颜色抖动
    transforms.Resize([pic_size,pic_size],antialias=False),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Apply transformations and add batch dimension
input_tensor = transform(image).unsqueeze(0).cuda()

# Load your model
model.eval()  # Set model to evaluation mode

# Initialize GradCAM with the target layer (e.g., last conv layer in ResNet)
cam_extractor = GradCAM(model, target_layer="resnet.layer4")  # Change the layer to the appropriate one

# Forward pass to get the model's output
output = model(input_tensor)

# Generate the CAM for the target class (if you know the class index, use it; otherwise, pick the predicted class)
activation_map = cam_extractor(output.squeeze(0).argmax().item(), output)
tmp=transforms.ToPILImage()(activation_map[0])
# Visualize the heatmap
# Convert the image to a format suitable for overlay
rgb_img = image  # Convert back to PIL for visualization

# Overlay the heatmap on the original image
result = overlay_mask(rgb_img, tmp, alpha=0.8)


# Display the result
plt.imshow(result)
plt.axis('off')
plt.show()


TypeError: Expected state_dict to be dict-like, got <class 'str'>.

In [53]:
# -*- coding: utf-8 -*-
"""
Created on 2019/8/4 上午9:53

@author: mick.yi

入口类

"""
import argparse
import os
import re

import cv2
import numpy as np
import torch
from skimage import io
from torch import nn
from torchvision import models

from interpretability.grad_cam import GradCAM, GradCamPlusPlus
from interpretability.guided_back_propagation import GuidedBackPropagation


def get_last_conv_name(net):
    """
    获取网络的最后一个卷积层的名字
    :param net:
    :return:
    """
    layer_name = None
    for name, m in net.named_modules():
        if isinstance(m, nn.Conv2d):
            layer_name = name
    return layer_name


def prepare_input(image):
    image = image.copy()

    # 归一化
    means = np.array([0.485, 0.456, 0.406])
    stds = np.array([0.229, 0.224, 0.225])
    image -= means
    image /= stds

    image = np.ascontiguousarray(np.transpose(image, (2, 0, 1)))  # channel first
    image = image[np.newaxis, ...]  # 增加batch维

    return torch.tensor(image, requires_grad=True)


def gen_cam(image, mask):
    """
    生成CAM图
    :param image: [H,W,C],原始图像
    :param mask: [H,W],范围0~1
    :return: tuple(cam,heatmap)
    """
    # mask转为heatmap
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    heatmap = heatmap[..., ::-1]  # gbr to rgb

    # 合并heatmap到原始图像
    cam = heatmap + np.float32(image)
    return norm_image(cam), (heatmap * 255).astype(np.uint8)


def norm_image(image):
    """
    标准化图像
    :param image: [H,W,C]
    :return:
    """
    image = image.copy()
    image -= np.max(np.min(image), 0)
    image /= np.max(image)
    image *= 255.
    return np.uint8(image)


def gen_gb(grad):
    """
    生guided back propagation 输入图像的梯度
    :param grad: tensor,[3,H,W]
    :return:
    """
    # 标准化
    grad = grad.data.numpy()
    gb = np.transpose(grad, (1, 2, 0))
    return gb


def save_image(image_dicts, input_image_name, network, output_dir):
    prefix = os.path.splitext(input_image_name)[0]
    for key, image in image_dicts.items():
        io.imsave(os.path.join(output_dir, '{}-{}-{}.jpg'.format(prefix, 'tmp', key)), image)


image_path=r"C:\Users\gaoge\Desktop\GHOME\01projs\03mycodes\tmp\12杆塔分类\MyGTData\train\耐张杆\IMG_1510.JPG"
network=model
weight_path=None
layer_name='resnet.layer4'
class_id=None
output_dir="results"
# 输入
img = io.imread(image_path)
img = np.float32(cv2.resize(img, (224, 224))) / 255
inputs = prepare_input(img)
# 输出图像
image_dict = {}
# 网络
net = network.to('cpu')
# Grad-CAM
layer_name = get_last_conv_name(net) if layer_name is None else layer_name
grad_cam = GradCAM(net, layer_name)
mask = grad_cam(inputs, class_id)  # cam mask
image_dict['cam'], image_dict['heatmap'] = gen_cam(img, mask)
grad_cam.remove_handlers()
# Grad-CAM++
grad_cam_plus_plus = GradCamPlusPlus(net, layer_name)
mask_plus_plus = grad_cam_plus_plus(inputs, class_id)  # cam mask
image_dict['cam++'], image_dict['heatmap++'] = gen_cam(img, mask_plus_plus)
grad_cam_plus_plus.remove_handlers()

# GuidedBackPropagation
gbp = GuidedBackPropagation(net)
inputs.grad.zero_()  # 梯度置零
grad = gbp(inputs)

gb = gen_gb(grad)
image_dict['gb'] = norm_image(gb)
# 生成Guided Grad-CAM
cam_gb = gb * mask[..., np.newaxis]
image_dict['cam_gb'] = norm_image(cam_gb)

save_image(image_dict, os.path.basename(image_path), network, output_dir)



feature shape:torch.Size([1, 2048, 7, 7])
feature shape:torch.Size([1, 2048, 7, 7])


  io.imsave(os.path.join(output_dir, '{}-{}-{}.jpg'.format(prefix, 'tmp', key)), image)
  io.imsave(os.path.join(output_dir, '{}-{}-{}.jpg'.format(prefix, 'tmp', key)), image)
