# COMP 9517


# Error Analysis

By Zijun Zhou @ z5593866

In [None]:
!pip install pyyaml



In [None]:
# Avoid conflicts!
!pip install -q captum==0.7.0 --no-deps

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.3 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━[0m [32m0.8/1.3 MB[0m [31m25.0 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m16.7 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# 1: 基本配置 & 安装依赖

ROOT = "/content/9517_data"

import os, shutil, random
from pathlib import Path
import numpy as np

os.makedirs(ROOT, exist_ok=True)
%cd $ROOT

print("当前工作目录:", os.getcwd())
print("目录下文件:\n", os.listdir("."))
print("当前 numpy 版本:", np.__version__)


/content/9517_data
当前工作目录: /content/9517_data
目录下文件:
 ['crop_from_yolo_labels.py', 'train', 'test', 'data.yaml', 'resnet50_cls_best.pth', 'valid']
当前 numpy 版本: 2.0.2


In [None]:
# 2: Regenerate the classification dataset cls_data/ (train + valid)
import os, shutil

# If a cls_data folder from an earlier run exists, delete it first so that
# old crops cannot get mixed up with the freshly generated ones.
cls_root = os.path.join(ROOT, "cls_data")
if os.path.exists(cls_root):
    shutil.rmtree(cls_root)
    print("已删除旧的 cls_data/")

# The script is looked up relative to the current working directory, so this
# cell relies on the working directory already being ROOT.
assert os.path.exists("crop_from_yolo_labels.py"), "找不到 crop_from_yolo_labels.py，请确认放在 ROOT 目录"

# Run the cropping script; judging by its log output it writes
# cls_data/{train,valid}/<class name>/.
# NOTE(review): the script's exit status is not checked here.
!python crop_from_yolo_labels.py

# List the regenerated directory tree (two levels deep) as a visual check.
print("\n重新生成后的 cls_data 结构：")
!find cls_data -maxdepth 2 -type d | sort


[CROP] 开始裁剪 train，共 11502 张图像 ...
[CROP] 完成 train 裁剪。
[CROP] 开始裁剪 valid，共 1095 张图像 ...
[CROP] 完成 valid 裁剪。
完成：cls_data/{train,valid}/<类名>/

重新生成后的 cls_data 结构：
cls_data
cls_data/train
cls_data/train/Ants
cls_data/train/Bees
cls_data/train/Beetles
cls_data/train/Caterpillars
cls_data/train/Earthworms
cls_data/train/Earwigs
cls_data/train/Grasshoppers
cls_data/train/Moths
cls_data/train/Slugs
cls_data/train/Snails
cls_data/train/Wasps
cls_data/train/Weevils
cls_data/valid
cls_data/valid/Ants
cls_data/valid/Bees
cls_data/valid/Beetles
cls_data/valid/Caterpillars
cls_data/valid/Earthworms
cls_data/valid/Earwigs
cls_data/valid/Grasshoppers
cls_data/valid/Moths
cls_data/valid/Slugs
cls_data/valid/Snails
cls_data/valid/Wasps
cls_data/valid/Weevils


In [None]:
# 3: Load the trained ResNet50 classifier

import torch
import torch.nn as nn
from torchvision import models, transforms, datasets

# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("使用设备:", device)

# The checkpoint is expected in the current working directory.
ckpt_path = Path("resnet50_cls_best.pth")
assert ckpt_path.exists(), "找不到 resnet50_cls_best.pth，请确认上传到 ROOT 目录"

# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
ckpt = torch.load(ckpt_path, map_location=device)
assert all(key in ckpt for key in ("state_dict", "classes")), "权重文件中缺少 state_dict 或 classes"

# The class list stored with the weights defines the output index -> name map.
ckpt_classes = ckpt["classes"]
num_classes = len(ckpt_classes)
print("从 checkpoint 中恢复的类别列表:", ckpt_classes)

# Build a ResNet50 without pretrained weights, swap the final fully connected
# layer for one with `num_classes` outputs, then restore the trained weights.
model = models.resnet50(weights=None)
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_classes)
model.load_state_dict(ckpt["state_dict"])

# Move to the target device and switch to inference mode.
model.to(device).eval()

print("ResNet50 分类器已加载完成。")


使用设备: cpu
从 checkpoint 中恢复的类别列表: ['Ants', 'Bees', 'Beetles', 'Caterpillars', 'Earthworms', 'Earwigs', 'Grasshoppers', 'Moths', 'Slugs', 'Snails', 'Wasps', 'Weevils']
ResNet50 分类器已加载完成。


In [None]:
# 4: Build the Dataset and DataLoader for cls_data/valid

from torch.utils.data import DataLoader

# Side length (pixels) of the square input fed to the network.
IMGSZ = 224

# Evaluation preprocessing: resize to 1.15x the input size, centre-crop to
# IMGSZ, convert to a tensor and normalise each channel.
eval_tf = transforms.Compose(
    [
        transforms.Resize(int(IMGSZ * 1.15)),
        transforms.CenterCrop(IMGSZ),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

val_dir = os.path.join(ROOT, "cls_data", "valid")
assert os.path.isdir(val_dir), f"找不到目录: {val_dir}"

# shuffle=False keeps the loader order identical to val_dataset.samples,
# which the inference cell relies on to recover each image path.
val_dataset = datasets.ImageFolder(val_dir, transform=eval_tf)
val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=2,
)

print(f"验证集样本数: {len(val_dataset)}")
print("ImageFolder 解析出的类别:", val_dataset.classes)

# The folder-derived class order must match the order stored in the
# checkpoint; otherwise predicted indices and ground-truth indices would
# refer to different classes.
assert val_dataset.classes == ckpt_classes, (
    f"类名顺序不一致！ImageFolder={val_dataset.classes}, ckpt={ckpt_classes}"
)
print("✔ 类别顺序与训练时一致，可以放心比较预测 vs GT。")


验证集样本数: 1341
ImageFolder 解析出的类别: ['Ants', 'Bees', 'Beetles', 'Caterpillars', 'Earthworms', 'Earwigs', 'Grasshoppers', 'Moths', 'Slugs', 'Snails', 'Wasps', 'Weevils']
✔ 类别顺序与训练时一致，可以放心比较预测 vs GT。


In [None]:
# 5: Run the ResNet once over cls_data/valid and collect the predictions

import torch
from tqdm import tqdm

# One dict per validation sample with the keys:
#   path, true_idx, true_name, pred_idx, pred_name, prob
records = []

model.eval()
with torch.no_grad():
    # Running position inside val_dataset.samples. val_loader is created with
    # shuffle=False, so batches arrive in dataset order and the next
    # `inputs.size(0)` entries of `samples` belong to the current batch.
    offset = 0
    for inputs, targets in tqdm(val_loader, desc="推理 valid"):
        logits = model(inputs.to(device))
        probs = torch.softmax(logits, dim=1)
        # Highest class probability and its index for every sample.
        confs, preds = probs.max(dim=1)

        n_in_batch = inputs.size(0)
        batch_samples = val_dataset.samples[offset:offset + n_in_batch]
        offset += n_in_batch

        # .tolist() converts each tensor once instead of calling .item() per element.
        for (img_path, sample_label), true_idx, pred_idx, prob in zip(
            batch_samples, targets.tolist(), preds.tolist(), confs.tolist()
        ):
            # Guard the path <-> prediction pairing: the label stored next to
            # the path must equal the label the loader actually delivered.
            assert sample_label == true_idx, (
                f"loader order does not match val_dataset.samples at {img_path}"
            )
            records.append({
                "path": img_path,
                "true_idx": true_idx,
                "true_name": ckpt_classes[true_idx],
                "pred_idx": pred_idx,
                "pred_name": ckpt_classes[pred_idx],
                "prob": prob,
            })

# Every validation sample must have produced exactly one record.
assert len(records) == len(val_dataset), "number of records differs from dataset size"
print(f"共记录样本: {len(records)}")

# Split by whether the top-1 prediction matches the ground truth.
correct = [r for r in records if r["true_idx"] == r["pred_idx"]]
wrong   = [r for r in records if r["true_idx"] != r["pred_idx"]]

print(f"预测正确: {len(correct)} 个")
print(f"预测错误: {len(wrong)} 个")


推理 valid: 100%|██████████| 42/42 [06:06<00:00,  8.72s/it]

共记录样本: 1341
预测正确: 1149 个
预测错误: 192 个





In [None]:
# === 6: Captum interpretability and visualisation helpers ===

from captum.attr import Saliency, IntegratedGradients
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# Fix the seeds for slightly more stable results (the randomness mostly comes
# from training; the attribution step itself is almost deterministic).
random.seed(42)
torch.manual_seed(42)

# Geometry-only counterpart of eval_tf (same Resize + CenterCrop, but no
# ToTensor / Normalize): produces the image the heatmaps are drawn over.
vis_tf = transforms.Compose(
    [
        transforms.Resize(int(IMGSZ * 1.15)),
        transforms.CenterCrop(IMGSZ),
    ]
)

# Attribution methods, both wrapped around the loaded classifier.
saliency = Saliency(model)
ig = IntegratedGradients(model)

def compute_attributions(img_path, target_idx, n_steps=32):
    """
    Compute Saliency and Integrated Gradients heatmaps for a single image.

    Args:
        img_path: path of the image file to explain.
        target_idx: index of the class whose score is attributed.
        n_steps: number of steps passed to Integrated Gradients. More steps
            cost proportionally more compute. Defaults to 32, the value this
            notebook has always used.

    Returns:
        (vis_img, saliency_map, ig_map) where
        - vis_img is the PIL image after vis_tf (Resize + CenterCrop), i.e.
          the same crop the model sees, without tensor conversion or
          normalisation;
        - saliency_map / ig_map are 2-D numpy arrays scaled to [0, 1] with
          the same height and width as the model input, so they can be
          overlaid on vis_img directly.
    """

    def to_heatmap(attr):
        """Collapse a (1, C, H, W) attribution to an (H, W) map scaled to [0, 1]."""
        # Strongest absolute attribution across the colour channels.
        heat = attr.abs().max(dim=1)[0].squeeze().detach().cpu().numpy()
        heat = heat - heat.min()
        peak = heat.max()
        # A constant map has peak 0 after the shift; leave it as all zeros
        # rather than dividing by zero.
        return heat / peak if peak > 0 else heat

    # Load the original image as RGB.
    orig = Image.open(img_path).convert("RGB")

    # Model input (includes Normalize), with a leading batch dimension.
    input_tensor = eval_tf(orig).unsqueeze(0).to(device)

    # Version used for display only (Resize + CenterCrop).
    vis_img = vis_tf(orig)

    # Saliency: gradient of the target score with respect to the input.
    model.zero_grad()
    input_s = input_tensor.clone().detach().requires_grad_(True)
    saliency_map = to_heatmap(saliency.attribute(input_s, target=target_idx))

    # Integrated Gradients against an all-zero baseline. The input is already
    # normalised, so zero corresponds to the per-channel mean colour rather
    # than to a black image.
    model.zero_grad()
    input_ig = input_tensor.clone().detach().requires_grad_(True)
    baseline = torch.zeros_like(input_ig)
    ig_map = to_heatmap(
        ig.attribute(input_ig, baselines=baseline, target=target_idx, n_steps=n_steps)
    )

    return vis_img, saliency_map, ig_map


def plot_triple(record, vis_img, sal_map, ig_map, out_path=None, show=False):
    """
    Draw a three-panel figure for one sample and optionally save it.

    Panels, left to right:
    - the cropped image, titled with ground truth, prediction and confidence;
    - the image with the Saliency heatmap overlaid;
    - the image with the Integrated Gradients heatmap overlaid.

    Args:
        record: dict providing "true_name", "pred_name" and "prob".
        vis_img: image to display in every panel.
        sal_map: Saliency heatmap drawn over vis_img.
        ig_map: Integrated Gradients heatmap drawn over vis_img.
        out_path: if given, the figure is written there (dpi=200); missing
            parent directories are created.
        show: if True the figure is shown with plt.show(); otherwise it is
            closed to release its memory.
    """
    gt = record["true_name"]
    pred = record["pred_name"]
    prob = record["prob"]

    fig, axes = plt.subplots(1, 3, figsize=(9, 3))

    # (title, overlay) per panel; None means "image only".
    panels = [
        (f"Original\nGT: {gt}\nPred: {pred}\nConf: {prob:.2f}", None),
        ("Saliency", sal_map),
        ("Integrated Gradients", ig_map),
    ]
    for ax, (title, heatmap) in zip(axes, panels):
        ax.imshow(vis_img)
        if heatmap is not None:
            # Semi-transparent so the underlying image stays visible.
            ax.imshow(heatmap, cmap="jet", alpha=0.5)
        ax.set_title(title)
        ax.axis("off")

    fig.tight_layout()

    if out_path is not None:
        # dirname is "" for a bare file name, and os.makedirs("") raises, so
        # only create the directory when there is one.
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(out_path, dpi=200, bbox_inches="tight")
        print("已保存:", out_path)

    if show:
        plt.show()
    else:
        plt.close(fig)


In [None]:
# 7: 生成 2–4 组正确样本 + 错误样本的可解释性图像

OUT_DIR = os.path.join(ROOT, "xai_resnet")
os.makedirs(OUT_DIR, exist_ok=True)

N_CORRECT = 20
N_WRONG   = 20


sample_correct = random.sample(correct, min(N_CORRECT, len(correct))) if len(correct) > 0 else []
sample_wrong   = random.sample(wrong,   min(N_WRONG,   len(wrong)))   if len(wrong)   > 0 else []

print(f"用于可视化的正确样本数: {len(sample_correct)}")
print(f"用于可视化的错误样本数: {len(sample_wrong)}")

# corrent sample
for i, rec in enumerate(sample_correct):
    vis_img, sal_map, ig_map = compute_attributions(
        img_path=rec["path"],
        target_idx=rec["pred_idx"],
    )
    fname = f"correct_{i:02d}_{rec['true_name']}_as_{rec['pred_name']}.png"
    out_path = os.path.join(OUT_DIR, "correct", fname)
    plot_triple(rec, vis_img, sal_map, ig_map, out_path=out_path, show=False)

# wrong sample
for i, rec in enumerate(sample_wrong):
    vis_img, sal_map, ig_map = compute_attributions(
        img_path=rec["path"],
        target_idx=rec["pred_idx"],
    )
    fname = f"wrong_{i:02d}_{rec['true_name']}_as_{rec['pred_name']}.png"
    out_path = os.path.join(OUT_DIR, "wrong", fname)
    plot_triple(rec, vis_img, sal_map, ig_map, out_path=out_path, show=False)

print("\n所有可视化已生成，查看目录：")
!find xai_resnet -maxdepth 2 -type f | sort


用于可视化的错误样本数: 50
已保存: /content/9517_data/xai_resnet/wrong/wrong_00_Slugs_as_Grasshoppers.png
已保存: /content/9517_data/xai_resnet/wrong/wrong_01_Caterpillars_as_Ants.png
已保存: /content/9517_data/xai_resnet/wrong/wrong_02_Ants_as_Bees.png
已保存: /content/9517_data/xai_resnet/wrong/wrong_03_Beetles_as_Weevils.png
已保存: /content/9517_data/xai_resnet/wrong/wrong_04_Ants_as_Earwigs.png
已保存: /content/9517_data/xai_resnet/wrong/wrong_05_Caterpillars_as_Wasps.png
已保存: /content/9517_data/xai_resnet/wrong/wrong_06_Earthworms_as_Caterpillars.png
已保存: /content/9517_data/xai_resnet/wrong/wrong_07_Caterpillars_as_Ants.png
已保存: /content/9517_data/xai_resnet/wrong/wrong_08_Ants_as_Grasshoppers.png
已保存: /content/9517_data/xai_resnet/wrong/wrong_09_Beetles_as_Weevils.png
已保存: /content/9517_data/xai_resnet/wrong/wrong_10_Grasshoppers_as_Ants.png
已保存: /content/9517_data/xai_resnet/wrong/wrong_11_Wasps_as_Bees.png
已保存: /content/9517_data/xai_resnet/wrong/wrong_12_Wasps_as_Bees.png
已保存: /content/9517_data/xai_resn

## Download the generated figures

In [None]:
source_dir = "/content/9517_data/xai_resnet/correct"
output_zip = "/content/correct_images.zip"


!zip -r $output_zip $source_dir
print("Finished!")

  adding: content/9517_data/xai_resnet/correct/ (stored 0%)
  adding: content/9517_data/xai_resnet/correct/correct_15_Bees_as_Bees.png (deflated 0%)
  adding: content/9517_data/xai_resnet/correct/correct_01_Caterpillars_as_Caterpillars.png (deflated 1%)
  adding: content/9517_data/xai_resnet/correct/correct_04_Earthworms_as_Earthworms.png (deflated 1%)
  adding: content/9517_data/xai_resnet/correct/correct_12_Earthworms_as_Earthworms.png (deflated 1%)
  adding: content/9517_data/xai_resnet/correct/correct_09_Snails_as_Snails.png (deflated 1%)
  adding: content/9517_data/xai_resnet/correct/correct_11_Beetles_as_Beetles.png (deflated 1%)
  adding: content/9517_data/xai_resnet/correct/correct_00_Ants_as_Ants.png (deflated 0%)
  adding: content/9517_data/xai_resnet/correct/correct_03_Snails_as_Snails.png (deflated 1%)
  adding: content/9517_data/xai_resnet/correct/correct_18_Moths_as_Moths.png (deflated 1%)
  adding: content/9517_data/xai_resnet/correct/correct_10_Moths_as_Moths.png (defla

In [None]:
source_dir_w = "/content/9517_data/xai_resnet/wrong"
output_zip_w = "/content/wrong_images.zip"


!zip -r $output_zip_w $source_dir_w
print("Finished!")

updating: content/9517_data/xai_resnet/wrong/ (stored 0%)
updating: content/9517_data/xai_resnet/wrong/wrong_13_Caterpillars_as_Earthworms.png (deflated 1%)
updating: content/9517_data/xai_resnet/wrong/wrong_14_Grasshoppers_as_Earthworms.png (deflated 1%)
updating: content/9517_data/xai_resnet/wrong/wrong_12_Slugs_as_Earthworms.png (deflated 1%)
updating: content/9517_data/xai_resnet/wrong/wrong_01_Caterpillars_as_Beetles.png (deflated 1%)
updating: content/9517_data/xai_resnet/wrong/wrong_00_Slugs_as_Caterpillars.png (deflated 1%)
updating: content/9517_data/xai_resnet/wrong/wrong_19_Beetles_as_Weevils.png (deflated 1%)
updating: content/9517_data/xai_resnet/wrong/wrong_09_Grasshoppers_as_Ants.png (deflated 1%)
updating: content/9517_data/xai_resnet/wrong/wrong_16_Snails_as_Bees.png (deflated 1%)
updating: content/9517_data/xai_resnet/wrong/wrong_10_Caterpillars_as_Ants.png (deflated 1%)
updating: content/9517_data/xai_resnet/wrong/wrong_17_Ants_as_Bees.png (deflated 1%)
updating: con