## 导入工具包

In [1]:
import os
from tqdm import tqdm

import numpy as np
import pandas as pd

from PIL import Image

import torch
import torch.nn.functional as F

# 有 GPU 就用 GPU，没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

device cuda:0


## 图像预处理

In [2]:
from torchvision import transforms

# # 训练集图像预处理：缩放裁剪、图像增强、转 Tensor、归一化
# train_transform = transforms.Compose([transforms.RandomResizedCrop(224),
#                                       transforms.RandomHorizontalFlip(),
#                                       transforms.ToTensor(),
#                                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
#                                      ])

# 测试集图像预处理-RCTN：缩放、裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
                                    ])

## 载入测试集（和训练代码教程相同）

In [3]:
# 数据集文件夹路径
dataset_dir = '/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb/cyy/work/图像分类/cutsave_split'
test_path = os.path.join(dataset_dir, 'val')
from torchvision import datasets
# 载入测试集
test_dataset = datasets.ImageFolder(test_path, test_transform)
print('测试集图像数量', len(test_dataset))
print('类别个数', len(test_dataset.classes))
print('各类别名称', test_dataset.classes)
# 载入类别名称 和 ID索引号 的映射字典
idx_to_labels =np.load('/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb/cyy/work/图像分类/3-【Pytorch】迁移学习训练自己的图像分类模型/idx_to_labels.npy', allow_pickle=True).item()
# 获得类别名称
classes = list(idx_to_labels.values())
print(classes)

测试集图像数量 3200
类别个数 4
各类别名称 ['medium', 'none', 'strong', 'weak']
['medium', 'none', 'strong', 'weak']


## 导入训练好的模型

In [4]:
model = torch.load('/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb/cyy/work/图像分类/3-【Pytorch】迁移学习训练自己的图像分类模型/checkpoints/best-0.926.pth')
model = model.eval().to(device)

## 表格A-测试集图像路径及标注

In [5]:
test_dataset.imgs[:10]

[('/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb/cyy/work/图像分类/cutsave_split/val/medium/medium00002.jpg',
  0),
 ('/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb/cyy/work/图像分类/cutsave_split/val/medium/medium00003.jpg',
  0),
 ('/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb/cyy/work/图像分类/cutsave_split/val/medium/medium00006.jpg',
  0),
 ('/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb/cyy/work/图像分类/cutsave_split/val/medium/medium00009.jpg',
  0),
 ('/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb/cyy/work/图像分类/cutsave_split/val/medium/medium00017.jpg',
  0),
 ('/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb/cyy/work/图像分类/cutsave_split/val/medium/medium00022.jpg',
  0),
 ('/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb/cyy/work/图像分类/cutsave_split/val/medium/medium00028.jpg',
  0),
 ('/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb/cyy/work/图像分类/cutsave_split/val/medium/medium00044.jpg',
  0),
 ('/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb/cyy/work/图像分类/cutsave_split/val/medium/medium0

In [6]:
img_paths = [each[0] for each in test_dataset.imgs]

In [7]:
df = pd.DataFrame()
df['图像路径'] = img_paths
df['标注类别ID'] = test_dataset.targets
df['标注类别名称'] = [idx_to_labels[ID] for ID in test_dataset.targets]

In [8]:
df

Unnamed: 0,图像路径,标注类别ID,标注类别名称
0,/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb...,0,medium
1,/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb...,0,medium
2,/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb...,0,medium
3,/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb...,0,medium
4,/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb...,0,medium
5,/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb...,0,medium
6,/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb...,0,medium
7,/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb...,0,medium
8,/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb...,0,medium
9,/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb...,0,medium


## 表格B-测试集每张图像的图像分类预测结果，以及各类别置信度

In [9]:
# 记录 top-n 预测结果
n = 3

In [10]:
df_pred = pd.DataFrame()
for idx, row in tqdm(df.iterrows()):
    img_path = row['图像路径']
    img_pil = Image.open(img_path).convert('RGB')
    input_img = test_transform(img_pil).unsqueeze(0).to(device) # 预处理
    pred_logits = model(input_img) # 执行前向预测，得到所有类别的 logit 预测分数
    pred_softmax = F.softmax(pred_logits, dim=1) # 对 logit 分数做 softmax 运算

    pred_dict = {}

    top_n = torch.topk(pred_softmax, n) # 取置信度最大的 n 个结果
    pred_ids = top_n[1].cpu().detach().numpy().squeeze() # 解析出类别
    
    # top-n 预测结果
    for i in range(1, n+1):
        pred_dict['top-{}-预测ID'.format(i)] = pred_ids[i-1]
        pred_dict['top-{}-预测名称'.format(i)] = idx_to_labels[pred_ids[i-1]]
    pred_dict['top-n预测正确'] = row['标注类别ID'] in pred_ids
    # 每个类别的预测置信度
    for idx, each in enumerate(classes):
        pred_dict['{}-预测置信度'.format(each)] = pred_softmax[0][idx].cpu().detach().numpy()
        
    df_pred = df_pred.append(pred_dict, ignore_index=True)
    

3200it [02:41, 19.84it/s]


In [11]:
df_pred

Unnamed: 0,medium-预测置信度,none-预测置信度,strong-预测置信度,top-1-预测ID,top-1-预测名称,top-2-预测ID,top-2-预测名称,top-3-预测ID,top-3-预测名称,top-n预测正确,weak-预测置信度
0,0.7357371,0.0456867,0.08210941,0.0,medium,3.0,weak,2.0,strong,1.0,0.1364668
1,0.45540452,0.01306144,0.35999188,0.0,medium,2.0,strong,3.0,weak,1.0,0.1715422
2,0.7578047,0.009940599,0.17786004,0.0,medium,2.0,strong,3.0,weak,1.0,0.05439465
3,0.34346557,0.010454647,0.6322077,2.0,strong,0.0,medium,3.0,weak,1.0,0.013872127
4,0.4734628,0.0048073693,0.5154145,2.0,strong,0.0,medium,3.0,weak,1.0,0.0063154167
5,0.7664874,0.012169804,0.20866151,0.0,medium,2.0,strong,3.0,weak,1.0,0.012681235
6,0.87741566,0.0052705775,0.10027949,0.0,medium,2.0,strong,3.0,weak,1.0,0.017034208
7,0.5865845,0.010732329,0.32103476,0.0,medium,2.0,strong,3.0,weak,1.0,0.081648424
8,0.80665386,0.014740202,0.1501817,0.0,medium,2.0,strong,3.0,weak,1.0,0.028424254
9,0.87622076,0.0043690037,0.112622544,0.0,medium,2.0,strong,3.0,weak,1.0,0.0067877215


## 拼接AB两张表格

In [12]:
df = pd.concat([df, df_pred], axis=1)

In [13]:
df

Unnamed: 0,图像路径,标注类别ID,标注类别名称,medium-预测置信度,none-预测置信度,strong-预测置信度,top-1-预测ID,top-1-预测名称,top-2-预测ID,top-2-预测名称,top-3-预测ID,top-3-预测名称,top-n预测正确,weak-预测置信度
0,/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb...,0,medium,0.7357371,0.0456867,0.08210941,0.0,medium,3.0,weak,2.0,strong,1.0,0.1364668
1,/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb...,0,medium,0.45540452,0.01306144,0.35999188,0.0,medium,2.0,strong,3.0,weak,1.0,0.1715422
2,/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb...,0,medium,0.7578047,0.009940599,0.17786004,0.0,medium,2.0,strong,3.0,weak,1.0,0.05439465
3,/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb...,0,medium,0.34346557,0.010454647,0.6322077,2.0,strong,0.0,medium,3.0,weak,1.0,0.013872127
4,/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb...,0,medium,0.4734628,0.0048073693,0.5154145,2.0,strong,0.0,medium,3.0,weak,1.0,0.0063154167
5,/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb...,0,medium,0.7664874,0.012169804,0.20866151,0.0,medium,2.0,strong,3.0,weak,1.0,0.012681235
6,/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb...,0,medium,0.87741566,0.0052705775,0.10027949,0.0,medium,2.0,strong,3.0,weak,1.0,0.017034208
7,/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb...,0,medium,0.5865845,0.010732329,0.32103476,0.0,medium,2.0,strong,3.0,weak,1.0,0.081648424
8,/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb...,0,medium,0.80665386,0.014740202,0.1501817,0.0,medium,2.0,strong,3.0,weak,1.0,0.028424254
9,/media/zy/600b4f4d-dc6c-4b45-8c58-476ec9afb3fb...,0,medium,0.87622076,0.0043690037,0.112622544,0.0,medium,2.0,strong,3.0,weak,1.0,0.0067877215


## 导出完整表格

In [14]:
df.to_csv('测试集预测结果new VGG16.csv', index=False)