In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
from scipy.ndimage import gaussian_filter

class FeatureExtractor(nn.Module):
    """Truncated ImageNet-pretrained ResNet-18: the stem followed by layer1..layer3.

    Stopping after layer3 (instead of going through layer4) keeps a larger
    spatial feature map for localizing anomalies.
    """

    # Backbone stages in execution order; each is copied from the ResNet and
    # applied in this order in forward().
    _STAGE_NAMES = ('conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3')

    def __init__(self):
        super().__init__()
        backbone = torchvision.models.resnet18(pretrained=True)
        for stage_name in self._STAGE_NAMES:
            setattr(self, stage_name, getattr(backbone, stage_name))

    def forward(self, x):
        """Run ``x`` through every retained stage and return the layer3 output."""
        for stage_name in self._STAGE_NAMES:
            x = getattr(self, stage_name)(x)
        return x

def get_mnist_loaders(normal_class=5, batch_size=64):
    """Build MNIST loaders for one-class anomaly detection.

    The training loader yields only (shuffled) images of ``normal_class``; the test
    loader yields the whole, unshuffled test split with every digit. Images are
    resized to 224, replicated to three channels and normalized with the mean/std
    expected by the ImageNet-pretrained backbone.

    Returns:
        (train_loader, test_loader)
    """
    preprocess = transforms.Compose([
        transforms.Resize(224),
        transforms.Grayscale(3),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    def load_split(is_train):
        return torchvision.datasets.MNIST(root='./data', train=is_train,
                                          download=True, transform=preprocess)

    trainset = load_split(True)
    testset = load_split(False)

    # Positions of the "normal" digit inside the full training split
    normal_positions = np.flatnonzero(np.array(trainset.targets) == normal_class)

    train_loader = DataLoader(Subset(trainset, normal_positions), batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

def compute_anomaly_map(features, mean_feature):
    """Compute per-position anomaly maps as the L2 distance to a reference feature map.

    Args:
        features: tensor of shape (B, C, H, W).
        mean_feature: tensor of shape (C, H, W), e.g. the mean training feature map.

    Returns:
        numpy array of shape (B, H, W): the Euclidean distance over the channel
        axis at every spatial position.

    Raises:
        ValueError: if ``features`` is not 4-D.
    """
    if features.dim() != 4:
        raise ValueError(f"expected features of shape (B, C, H, W), got {tuple(features.shape)}")
    # Broadcast the reference map over the batch (on the same device as the features)
    diff = features - mean_feature.to(features.device).unsqueeze(0)
    distances = torch.norm(diff, dim=1)  # (B, H, W)
    # detach()/cpu() so the numpy conversion also works for GPU or grad-tracking inputs
    return distances.detach().cpu().numpy()

def visualize_anomalies(image, anomaly_map, threshold=0.90):
    """Plot an image, its anomaly heatmap and an overlay of the most anomalous regions.

    Args:
        image: CHW image tensor; only channel 0 is displayed (the MNIST pipeline
            replicates the grayscale digit into three channels).
        anomaly_map: 2-D array of per-position anomaly scores. It may be coarser
            than the image; the overlay is stretched to cover the whole image.
        threshold: percentile, as a fraction in [0, 1], of the normalized map above
            which positions are highlighted in the overlay.
    """
    # Gaussian smoothing suppresses isolated single-position spikes
    anomaly_map_smooth = gaussian_filter(anomaly_map, sigma=1.0)

    # Min-max normalize to [0, 1]; the epsilon guards against a perfectly flat map
    map_min = anomaly_map_smooth.min()
    map_range = anomaly_map_smooth.max() - map_min
    anomaly_map_norm = (anomaly_map_smooth - map_min) / (map_range + 1e-8)

    # Highlight only the positions above the requested percentile
    threshold_value = np.percentile(anomaly_map_norm, threshold * 100)
    mask = anomaly_map_norm > threshold_value

    # The heatmap is usually coarser than the image. Without an explicit extent,
    # imshow would place it in the image's top-left corner and re-zoom the axes to it.
    img_h, img_w = image.shape[-2:]
    image_extent = (-0.5, img_w - 0.5, img_h - 0.5, -0.5)

    plt.figure(figsize=(15, 5))

    # Original image (first channel)
    plt.subplot(131)
    plt.imshow(image[0], cmap='gray')
    plt.title('Original Image')
    plt.axis('off')

    # Anomaly score heatmap
    plt.subplot(132)
    plt.imshow(anomaly_map_norm, cmap='hot')
    plt.colorbar()
    plt.title('Anomaly Heatmap')
    plt.axis('off')

    # Thresholded heatmap overlaid on the image
    plt.subplot(133)
    plt.imshow(image[0], cmap='gray')
    overlay = np.ma.masked_where(~mask, anomaly_map_norm)
    plt.imshow(overlay, cmap='hot', alpha=0.6, extent=image_extent)
    plt.title('Overlay')
    plt.axis('off')

    plt.tight_layout()
    plt.show()

def main():
    """Run the MNIST one-class anomaly-detection demo end to end.

    Digit ``NORMAL_CLASS`` is the only class used for fitting, which is just the
    per-position mean of the extracted feature maps. Every other digit counts as
    an anomaly. Each test image is scored by the mean of its distance map to that
    mean; the image-level AUC-ROC is printed and the three highest-scoring
    anomalies are plotted.
    """
    NORMAL_CLASS = 5
    BATCH_SIZE = 64

    # Frozen, pretrained feature extractor
    feature_extractor = FeatureExtractor()
    feature_extractor.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    feature_extractor = feature_extractor.to(device)

    print("Loading data...")
    train_loader, test_loader = get_mnist_loaders(NORMAL_CLASS, BATCH_SIZE)

    # Keep a running float64 sum instead of concatenating every training feature
    # map: same mean, but memory no longer grows with the number of training images.
    print("Extracting training features...")
    feature_sum = None
    num_train = 0
    with torch.no_grad():
        for images, _ in tqdm(train_loader):
            features = feature_extractor(images.to(device))
            batch_sum = features.sum(dim=0, dtype=torch.float64)
            feature_sum = batch_sum if feature_sum is None else feature_sum + batch_sum
            num_train += features.shape[0]
    if num_train == 0:
        raise RuntimeError(f"No training images found for class {NORMAL_CLASS}")
    mean_feature = (feature_sum / num_train).to(torch.float32).cpu()

    # Score the test set batch by batch and keep only the small anomaly maps.
    # Holding every 3x224x224 test image plus its feature map for the whole test
    # split would need several GB; the few plotted images are re-read below.
    print("Extracting test features...")
    anomaly_map_batches = []
    test_labels = []
    with torch.no_grad():
        for images, labels in tqdm(test_loader):
            features = feature_extractor(images.to(device)).cpu()
            anomaly_map_batches.append(compute_anomaly_map(features, mean_feature))
            test_labels.extend(labels.numpy())
    anomaly_maps = np.concatenate(anomaly_map_batches, axis=0)
    test_labels = np.array(test_labels)

    # Image-level score = mean of the anomaly map
    anomaly_scores = anomaly_maps.mean(axis=(1, 2))

    # Every digit other than NORMAL_CLASS is the positive (anomalous) class
    true_labels = (test_labels != NORMAL_CLASS).astype(int)
    auc_score = roc_auc_score(true_labels, anomaly_scores)
    print(f"AUC-ROC Score: {auc_score:.4f}")

    # Visualize the highest-scoring true anomalies
    anomaly_indices = np.where(true_labels == 1)[0]
    top_anomalies = anomaly_indices[np.argsort(-anomaly_scores[anomaly_indices])[:3]]

    print("\nVisualizing top anomalies...")
    # The test loader is unshuffled, so dataset positions match anomaly_maps rows
    test_dataset = test_loader.dataset
    for idx in top_anomalies:
        print(f"Sample {idx} - True Label: {test_labels[idx]}")
        image, _ = test_dataset[int(idx)]
        visualize_anomalies(image, anomaly_maps[idx])

# Run the demo when this cell/script is executed directly (inside a Jupyter
# kernel __name__ is "__main__", so executing the cell runs main()).
if __name__ == "__main__":
    main()

In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
from scipy.ndimage import gaussian_filter
import os
from PIL import Image

class MVTecDataset(Dataset):
    """MVTec AD dataset for a single category directory.

    Expected layout under ``root_dir``::

        train/good/<image>
        test/<defect_type>/<image>            ("good" = normal, anything else = anomaly)
        ground_truth/<defect_type>/<stem>_mask.png

    Training items are ``(image, label)``. Test items are ``(image, label, mask_path)``,
    where ``label`` is 0 for good and 1 for anomalous samples and ``mask_path`` is
    ``None`` when no ground-truth mask file exists.
    """

    # Accepted image file extensions (matched case-insensitively)
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, is_train=True, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.is_train = is_train

        # Split directory to read images from
        self.img_dir = os.path.join(root_dir, 'train' if is_train else 'test')

        self.image_paths = []
        self.labels = []      # 0 for good, 1 for anomaly
        self.mask_paths = []  # only meaningful for the test split

        if is_train:
            # The training split contains only good samples
            good_dir = os.path.join(self.img_dir, 'good')
            for img_name in self._list_images(good_dir):
                self._append(os.path.join(good_dir, img_name), 0, None)
        else:
            # The test split has one sub-directory per defect type plus "good".
            # Sorting keeps sample order (and thus indices) reproducible across platforms.
            for defect_type in sorted(os.listdir(self.img_dir)):
                defect_dir = os.path.join(self.img_dir, defect_type)
                if not os.path.isdir(defect_dir):
                    continue
                is_good = defect_type == 'good'
                for img_name in self._list_images(defect_dir):
                    mask_path = None if is_good else self._find_mask(defect_type, img_name)
                    self._append(os.path.join(defect_dir, img_name), 0 if is_good else 1, mask_path)

    @classmethod
    def _list_images(cls, directory):
        """Sorted names of the image files directly inside ``directory``."""
        return sorted(name for name in os.listdir(directory)
                      if name.lower().endswith(cls.IMAGE_EXTENSIONS))

    def _find_mask(self, defect_type, img_name):
        """Ground-truth mask path for a defective test image, or ``None`` if it is missing."""
        # Swap only the real extension: chained str.replace would map "000.jpg" to
        # "000_mask_mask.png" and would also rewrite ".png" occurring inside the stem.
        stem = os.path.splitext(img_name)[0]
        mask_path = os.path.join(self.root_dir, 'ground_truth', defect_type, stem + '_mask.png')
        return mask_path if os.path.exists(mask_path) else None

    def _append(self, image_path, label, mask_path):
        """Register one sample in the three parallel lists."""
        self.image_paths.append(image_path)
        self.labels.append(label)
        self.mask_paths.append(mask_path)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Context manager closes the underlying file once the RGB copy is made
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        label = self.labels[idx]

        if self.is_train:
            return image, label

        # Test items also carry the mask path for evaluation/visualization
        return image, label, self.mask_paths[idx]

class FeatureExtractor(nn.Module):
    """Truncated ImageNet-pretrained ResNet-18: the stem followed by layer1..layer3.

    Stopping after layer3 (instead of going through layer4) keeps a larger
    spatial feature map for localizing anomalies.
    """

    # Backbone stages in execution order; each is copied from the ResNet and
    # applied in this order in forward().
    _STAGE_NAMES = ('conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3')

    def __init__(self):
        super().__init__()
        backbone = torchvision.models.resnet18(pretrained=True)
        for stage_name in self._STAGE_NAMES:
            setattr(self, stage_name, getattr(backbone, stage_name))

    def forward(self, x):
        """Run ``x`` through every retained stage and return the layer3 output."""
        for stage_name in self._STAGE_NAMES:
            x = getattr(self, stage_name)(x)
        return x

def compute_anomaly_map(features, mean_feature):
    """Compute per-position anomaly maps as the L2 distance to a reference feature map.

    Args:
        features: tensor of shape (B, C, H, W).
        mean_feature: tensor of shape (C, H, W), e.g. the mean training feature map.

    Returns:
        numpy array of shape (B, H, W): the Euclidean distance over the channel
        axis at every spatial position.

    Raises:
        ValueError: if ``features`` is not 4-D.
    """
    if features.dim() != 4:
        raise ValueError(f"expected features of shape (B, C, H, W), got {tuple(features.shape)}")
    # Broadcast the reference map over the batch (on the same device as the features)
    diff = features - mean_feature.to(features.device).unsqueeze(0)
    distances = torch.norm(diff, dim=1)  # (B, H, W)
    # detach()/cpu() so the numpy conversion also works for GPU or grad-tracking inputs
    return distances.detach().cpu().numpy()

def visualize_anomalies(image, anomaly_map, mask_path=None, threshold=0.90):
    """Plot an image, its anomaly heatmap, a thresholded overlay and optionally the GT mask.

    Args:
        image: normalized CHW tensor on the CPU; it is de-normalized with the same
            mean/std used by the Normalize transform before display.
        anomaly_map: 2-D array of anomaly scores. It may be coarser than the image;
            the overlay is stretched to cover the whole image.
        mask_path: optional ground-truth mask path; the fourth panel is drawn only
            when the file exists.
        threshold: percentile, as a fraction in [0, 1], of the normalized map above
            which positions are highlighted in the overlay.
    """
    # Gaussian smoothing suppresses isolated single-position spikes
    anomaly_map_smooth = gaussian_filter(anomaly_map, sigma=1.0)

    # Min-max normalize to [0, 1]; the epsilon guards against a perfectly flat map
    map_min = anomaly_map_smooth.min()
    map_range = anomaly_map_smooth.max() - map_min
    anomaly_map_norm = (anomaly_map_smooth - map_min) / (map_range + 1e-8)

    # Highlight only the positions above the requested percentile
    threshold_value = np.percentile(anomaly_map_norm, threshold * 100)
    pred_mask = anomaly_map_norm > threshold_value

    # Decide on the GT panel up front so no empty fourth axis is reserved when
    # mask_path is given but the file does not exist.
    has_gt = bool(mask_path) and os.path.exists(mask_path)
    num_subplots = 4 if has_gt else 3
    plt.figure(figsize=(5 * num_subplots, 5))

    # Undo the normalization and clamp into the displayable [0, 1] range
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    img_denorm = torch.clamp(image * std + mean, 0, 1)
    img_hwc = img_denorm.permute(1, 2, 0)  # CHW -> HWC for imshow

    # The heatmap is usually coarser than the image. Without an explicit extent,
    # imshow would place it in the image's top-left corner and re-zoom the axes to it.
    img_h, img_w = img_hwc.shape[:2]
    image_extent = (-0.5, img_w - 0.5, img_h - 0.5, -0.5)

    # Original image
    plt.subplot(1, num_subplots, 1)
    plt.imshow(img_hwc)
    plt.title('Original Image')
    plt.axis('off')

    # Anomaly score heatmap
    plt.subplot(1, num_subplots, 2)
    plt.imshow(anomaly_map_norm, cmap='hot')
    plt.colorbar()
    plt.title('Anomaly Heatmap')
    plt.axis('off')

    # Thresholded heatmap overlaid on the image
    plt.subplot(1, num_subplots, 3)
    plt.imshow(img_hwc)
    overlay = np.ma.masked_where(~pred_mask, anomaly_map_norm)
    plt.imshow(overlay, cmap='hot', alpha=0.6, extent=image_extent)
    plt.title('Overlay')
    plt.axis('off')

    # Ground-truth mask for comparison, resized to the anomaly-map resolution
    if has_gt:
        plt.subplot(1, num_subplots, 4)
        with Image.open(mask_path) as gt_img:
            gt_mask = gt_img.convert('L')
        gt_mask = transforms.Resize(pred_mask.shape)(transforms.ToTensor()(gt_mask))
        plt.imshow(gt_mask[0], cmap='gray')
        plt.title('Ground Truth')
        plt.axis('off')

    plt.tight_layout()
    plt.show()

def main(data_root):
    """Evaluate the mean-feature anomaly detector on one MVTec AD category.

    Fits the mean feature map on ``train/good``, scores each test image by the
    spatial mean of its distance map, prints the image-level AUC-ROC and plots
    the three highest-scoring anomalous test images.

    Args:
        data_root: path of one MVTec AD category directory.
    """
    batch_size = 32

    # Preprocessing: fixed resize + the normalization the pretrained backbone expects
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    train_dataset = MVTecDataset(data_root, is_train=True, transform=transform)
    test_dataset = MVTecDataset(data_root, is_train=False, transform=transform)

    def collate_with_paths(batch):
        # The default collate cannot batch the None mask paths of good samples,
        # so the paths are kept as a plain Python list.
        images, labels, mask_paths = zip(*batch)
        return torch.stack(images), torch.tensor(labels), list(mask_paths)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             collate_fn=collate_with_paths)

    # Frozen, pretrained feature extractor
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    feature_extractor = FeatureExtractor()
    feature_extractor.eval()
    feature_extractor = feature_extractor.to(device)

    print("Extracting training features...")
    with torch.no_grad():
        train_features = torch.cat(
            [feature_extractor(batch.to(device)).cpu() for batch, _ in tqdm(train_loader)], 0
        )
    mean_feature = train_features.mean(dim=0)

    print("Extracting test features...")
    feature_chunks, image_chunks = [], []
    test_labels, test_mask_paths = [], []
    with torch.no_grad():
        for batch, labels, mask_paths in tqdm(test_loader):
            batch = batch.to(device)
            feature_chunks.append(feature_extractor(batch).cpu())
            test_labels.extend(labels.cpu().numpy())
            image_chunks.append(batch.cpu())
            test_mask_paths.extend(mask_paths)
    test_features = torch.cat(feature_chunks, 0)
    test_labels = np.array(test_labels)
    test_images = torch.cat(image_chunks, 0)

    # Per-position anomaly maps and image-level scores (mean over the map)
    anomaly_maps = compute_anomaly_map(test_features, mean_feature)
    anomaly_scores = anomaly_maps.mean(axis=(1, 2))

    auc_score = roc_auc_score(test_labels, anomaly_scores)
    print(f"AUC-ROC Score: {auc_score:.4f}")

    # Plot the highest-scoring true anomalies, if the split has any
    anomaly_indices = np.flatnonzero(test_labels == 1)
    if anomaly_indices.size == 0:
        return
    ranked = anomaly_indices[np.argsort(-anomaly_scores[anomaly_indices])]

    print("\nVisualizing top anomalies...")
    for idx in ranked[:3]:
        print(f"Anomaly score: {anomaly_scores[idx]:.4f}")
        visualize_anomalies(test_images[idx], anomaly_maps[idx], test_mask_paths[idx])

if __name__ == "__main__":
    # Path of one MVTec AD category, e.g. "D:/PROJECTS/MVTEC-AD/bottle".
    # Set the MVTEC_ROOT environment variable to use another location without
    # editing the notebook; the previous hard-coded path stays the default.
    data_root = os.environ.get("MVTEC_ROOT", "D:/Projects/MVTEC-AD/bottle")
    main(data_root)

In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
from scipy.ndimage import gaussian_filter
import os
from PIL import Image

class MVTecDataset(Dataset):
    """MVTec AD dataset for a single category directory.

    Expected layout under ``root_dir``::

        train/good/<image>
        test/<defect_type>/<image>            ("good" = normal, anything else = anomaly)
        ground_truth/<defect_type>/<stem>_mask.png

    Training items are ``(image, label)``. Test items are ``(image, label, mask_path)``,
    where ``label`` is 0 for good and 1 for anomalous samples and ``mask_path`` is
    ``None`` when no ground-truth mask file exists.
    """

    # Accepted image file extensions (matched case-insensitively)
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, is_train=True, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.is_train = is_train

        # Split directory to read images from
        self.img_dir = os.path.join(root_dir, 'train' if is_train else 'test')

        self.image_paths = []
        self.labels = []      # 0 for good, 1 for anomaly
        self.mask_paths = []  # only meaningful for the test split

        if is_train:
            # The training split contains only good samples
            good_dir = os.path.join(self.img_dir, 'good')
            for img_name in self._list_images(good_dir):
                self._append(os.path.join(good_dir, img_name), 0, None)
        else:
            # The test split has one sub-directory per defect type plus "good".
            # Sorting keeps sample order (and thus indices) reproducible across platforms.
            for defect_type in sorted(os.listdir(self.img_dir)):
                defect_dir = os.path.join(self.img_dir, defect_type)
                if not os.path.isdir(defect_dir):
                    continue
                is_good = defect_type == 'good'
                for img_name in self._list_images(defect_dir):
                    mask_path = None if is_good else self._find_mask(defect_type, img_name)
                    self._append(os.path.join(defect_dir, img_name), 0 if is_good else 1, mask_path)

    @classmethod
    def _list_images(cls, directory):
        """Sorted names of the image files directly inside ``directory``."""
        return sorted(name for name in os.listdir(directory)
                      if name.lower().endswith(cls.IMAGE_EXTENSIONS))

    def _find_mask(self, defect_type, img_name):
        """Ground-truth mask path for a defective test image, or ``None`` if it is missing."""
        # Swap only the real extension: chained str.replace would map "000.jpg" to
        # "000_mask_mask.png" and would also rewrite ".png" occurring inside the stem.
        stem = os.path.splitext(img_name)[0]
        mask_path = os.path.join(self.root_dir, 'ground_truth', defect_type, stem + '_mask.png')
        return mask_path if os.path.exists(mask_path) else None

    def _append(self, image_path, label, mask_path):
        """Register one sample in the three parallel lists."""
        self.image_paths.append(image_path)
        self.labels.append(label)
        self.mask_paths.append(mask_path)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Context manager closes the underlying file once the RGB copy is made
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        label = self.labels[idx]

        if self.is_train:
            return image, label

        # Test items also carry the mask path for evaluation/visualization
        return image, label, self.mask_paths[idx]

class FeatureExtractor(nn.Module):
    """Truncated ImageNet-pretrained ResNet-18: the stem followed by layer1..layer3.

    Stopping after layer3 (instead of going through layer4) keeps a larger
    spatial feature map for localizing anomalies.
    """

    # Backbone stages in execution order; each is copied from the ResNet and
    # applied in this order in forward().
    _STAGE_NAMES = ('conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3')

    def __init__(self):
        super().__init__()
        backbone = torchvision.models.resnet18(pretrained=True)
        for stage_name in self._STAGE_NAMES:
            setattr(self, stage_name, getattr(backbone, stage_name))

    def forward(self, x):
        """Run ``x`` through every retained stage and return the layer3 output."""
        for stage_name in self._STAGE_NAMES:
            x = getattr(self, stage_name)(x)
        return x

def compute_anomaly_map(features, mean_feature, output_size=(224, 224)):
    """Compute upsampled per-position anomaly maps (L2 distance to a reference feature map).

    Args:
        features: tensor of shape (B, C, H, W).
        mean_feature: tensor of shape (C, H, W), e.g. the mean training feature map.
        output_size: (height, width) the distance maps are bilinearly resized to;
            defaults to the 224x224 input resolution used in this notebook.

    Returns:
        numpy array of shape (B, output_size[0], output_size[1]).

    Raises:
        ValueError: if ``features`` is not 4-D.
    """
    if features.dim() != 4:
        raise ValueError(f"expected features of shape (B, C, H, W), got {tuple(features.shape)}")
    # Broadcast the reference map over the batch (on the same device as the features)
    diff = features - mean_feature.to(features.device).unsqueeze(0)
    distances = torch.norm(diff, dim=1)  # (B, H, W)

    # Bilinear upsampling to the requested resolution (adds/removes a channel axis)
    distances_upsampled = nn.functional.interpolate(
        distances.unsqueeze(1), size=tuple(output_size), mode='bilinear', align_corners=False
    ).squeeze(1)

    # detach()/cpu() so the numpy conversion also works for GPU or grad-tracking inputs
    return distances_upsampled.detach().cpu().numpy()

def visualize_anomalies(image, anomaly_map, mask_path=None, threshold=0.90):
    """Plot an image, its anomaly heatmap, a thresholded overlay and optionally the GT mask.

    Args:
        image: normalized CHW tensor on the CPU; it is de-normalized with the same
            mean/std used by the Normalize transform before display.
        anomaly_map: 2-D array of anomaly scores. If its resolution differs from the
            image, the overlay is still stretched to cover the whole image.
        mask_path: optional ground-truth mask path; the fourth panel is drawn only
            when the file exists.
        threshold: percentile, as a fraction in [0, 1], of the normalized map above
            which positions are highlighted in the overlay.
    """
    # Gaussian smoothing suppresses isolated single-position spikes
    anomaly_map_smooth = gaussian_filter(anomaly_map, sigma=1.0)

    # Min-max normalize to [0, 1]; the epsilon guards against a perfectly flat map
    map_min = anomaly_map_smooth.min()
    map_range = anomaly_map_smooth.max() - map_min
    anomaly_map_norm = (anomaly_map_smooth - map_min) / (map_range + 1e-8)

    # Highlight only the positions above the requested percentile
    threshold_value = np.percentile(anomaly_map_norm, threshold * 100)
    pred_mask = anomaly_map_norm > threshold_value

    # Decide on the GT panel up front so no empty fourth axis is reserved when
    # mask_path is given but the file does not exist.
    has_gt = bool(mask_path) and os.path.exists(mask_path)
    num_subplots = 4 if has_gt else 3
    plt.figure(figsize=(5 * num_subplots, 5))

    # Undo the normalization and clamp into the displayable [0, 1] range
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    img_denorm = torch.clamp(image * std + mean, 0, 1)
    img_hwc = img_denorm.permute(1, 2, 0)  # CHW -> HWC for imshow

    # Pin the overlay to the image's extent so it stays aligned even when the
    # heatmap resolution differs from the input resolution.
    img_h, img_w = img_hwc.shape[:2]
    image_extent = (-0.5, img_w - 0.5, img_h - 0.5, -0.5)

    # Original image
    plt.subplot(1, num_subplots, 1)
    plt.imshow(img_hwc)
    plt.title('Original Image')
    plt.axis('off')

    # Anomaly score heatmap
    plt.subplot(1, num_subplots, 2)
    plt.imshow(anomaly_map_norm, cmap='hot')
    plt.colorbar()
    plt.title('Anomaly Heatmap')
    plt.axis('off')

    # Thresholded heatmap overlaid on the image
    plt.subplot(1, num_subplots, 3)
    plt.imshow(img_hwc)
    overlay = np.ma.masked_where(~pred_mask, anomaly_map_norm)
    plt.imshow(overlay, cmap='hot', alpha=0.6, extent=image_extent)
    plt.title('Overlay')
    plt.axis('off')

    # Ground-truth mask for comparison, resized to the anomaly-map resolution
    if has_gt:
        plt.subplot(1, num_subplots, 4)
        with Image.open(mask_path) as gt_img:
            gt_mask = gt_img.convert('L')
        gt_mask = transforms.Resize(pred_mask.shape)(transforms.ToTensor()(gt_mask))
        plt.imshow(gt_mask[0], cmap='gray')
        plt.title('Ground Truth')
        plt.axis('off')

    plt.tight_layout()
    plt.show()

def main(data_root):
    """Evaluate the mean-feature anomaly detector on one MVTec AD category.

    Fits the mean feature map on ``train/good``, scores each test image by the
    spatial mean of its (upsampled) distance map, prints the image-level AUC-ROC
    and plots the three highest-scoring anomalous test images.

    Args:
        data_root: path of one MVTec AD category directory.
    """
    batch_size = 32

    # Preprocessing: fixed resize + the normalization the pretrained backbone expects
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    train_dataset = MVTecDataset(data_root, is_train=True, transform=transform)
    test_dataset = MVTecDataset(data_root, is_train=False, transform=transform)

    def collate_with_paths(batch):
        # The default collate cannot batch the None mask paths of good samples,
        # so the paths are kept as a plain Python list.
        images, labels, mask_paths = zip(*batch)
        return torch.stack(images), torch.tensor(labels), list(mask_paths)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             collate_fn=collate_with_paths)

    # Frozen, pretrained feature extractor
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    feature_extractor = FeatureExtractor()
    feature_extractor.eval()
    feature_extractor = feature_extractor.to(device)

    print("Extracting training features...")
    with torch.no_grad():
        train_features = torch.cat(
            [feature_extractor(batch.to(device)).cpu() for batch, _ in tqdm(train_loader)], 0
        )
    mean_feature = train_features.mean(dim=0)

    print("Extracting test features...")
    feature_chunks, image_chunks = [], []
    test_labels, test_mask_paths = [], []
    with torch.no_grad():
        for batch, labels, mask_paths in tqdm(test_loader):
            batch = batch.to(device)
            feature_chunks.append(feature_extractor(batch).cpu())
            test_labels.extend(labels.cpu().numpy())
            image_chunks.append(batch.cpu())
            test_mask_paths.extend(mask_paths)
    test_features = torch.cat(feature_chunks, 0)
    test_labels = np.array(test_labels)
    test_images = torch.cat(image_chunks, 0)

    # Per-position anomaly maps and image-level scores (mean over the map)
    anomaly_maps = compute_anomaly_map(test_features, mean_feature)
    anomaly_scores = anomaly_maps.mean(axis=(1, 2))

    auc_score = roc_auc_score(test_labels, anomaly_scores)
    print(f"AUC-ROC Score: {auc_score:.4f}")

    # Plot the highest-scoring true anomalies, if the split has any
    anomaly_indices = np.flatnonzero(test_labels == 1)
    if anomaly_indices.size == 0:
        return
    ranked = anomaly_indices[np.argsort(-anomaly_scores[anomaly_indices])]

    print("\nVisualizing top anomalies...")
    for idx in ranked[:3]:
        print(f"Anomaly score: {anomaly_scores[idx]:.4f}")
        visualize_anomalies(test_images[idx], anomaly_maps[idx], test_mask_paths[idx])

if __name__ == "__main__":
    # Path of one MVTec AD category, e.g. "D:/PROJECTS/MVTEC-AD/bottle".
    # Set the MVTEC_ROOT environment variable to use another location without
    # editing the notebook; the previous hard-coded path stays the default.
    data_root = os.environ.get("MVTEC_ROOT", "D:/Projects/MVTEC-AD/bottle")
    main(data_root)

In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
from scipy.ndimage import gaussian_filter
import os
from PIL import Image
from typing import Optional, Tuple, List, Dict, Union
import json

class MVTecDataset(Dataset):
    """MVTec anomaly-detection dataset for one category directory.

    Expected layout under ``root_dir``::

        train/good/<image>
        test/<defect_type>/<image>            ("good" = normal, anything else = anomaly)
        ground_truth/<defect_type>/<stem>_mask.png

    Training items are ``(image, label)``. Test items are ``(image, label, mask_path)``,
    where ``label`` is 0 for good and 1 for anomalous samples and ``mask_path`` is
    ``None`` when no ground-truth mask file exists.
    """

    # Accepted image file extensions (matched case-insensitively)
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir: str, is_train: bool = True, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.is_train = is_train
        self.img_dir = os.path.join(root_dir, 'train' if is_train else 'test')

        self.image_paths = []  # absolute/relative paths of the images
        self.labels = []       # 0 for good, 1 for anomaly
        self.mask_paths = []   # GT mask path or None (only meaningful for the test split)

        self._load_dataset()

    def _load_dataset(self):
        """Populate the sample lists for the configured split."""
        if self.is_train:
            self._load_train_data()
        else:
            self._load_test_data()

    def _load_train_data(self):
        """Load the training split (good samples only)."""
        good_dir = os.path.join(self.img_dir, 'good')
        for img_name in self._list_images(good_dir):
            self._append(os.path.join(good_dir, img_name), 0, None)

    def _load_test_data(self):
        """Load the test split (good and defective samples)."""
        # Sorting keeps sample order (and thus indices) reproducible across platforms
        for defect_type in sorted(os.listdir(self.img_dir)):
            defect_dir = os.path.join(self.img_dir, defect_type)
            if not os.path.isdir(defect_dir):
                continue
            is_good = defect_type == 'good'
            for img_name in self._list_images(defect_dir):
                mask_path = None if is_good else self._find_mask(defect_type, img_name)
                self._append(os.path.join(defect_dir, img_name), 0 if is_good else 1, mask_path)

    @classmethod
    def _list_images(cls, directory: str) -> List[str]:
        """Sorted names of the image files directly inside ``directory``."""
        return sorted(name for name in os.listdir(directory)
                      if name.lower().endswith(cls.IMAGE_EXTENSIONS))

    def _find_mask(self, defect_type: str, img_name: str) -> Optional[str]:
        """Ground-truth mask path for a defective test image, or ``None`` if it is missing."""
        # Swap only the real extension: chained str.replace would map "000.jpg" to
        # "000_mask_mask.png" and would also rewrite ".png" occurring inside the stem.
        stem = os.path.splitext(img_name)[0]
        mask_path = os.path.join(self.root_dir, 'ground_truth', defect_type, stem + '_mask.png')
        return mask_path if os.path.exists(mask_path) else None

    def _append(self, image_path: str, label: int, mask_path: Optional[str]):
        """Register one sample in the three parallel lists."""
        self.image_paths.append(image_path)
        self.labels.append(label)
        self.mask_paths.append(mask_path)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Context manager closes the underlying file once the RGB copy is made
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        label = self.labels[idx]

        if self.is_train:
            return image, label
        return image, label, self.mask_paths[idx]

# class FeatureExtractor(nn.Module):
#     """特征提取器类，基于预训练的ResNet18"""
    
#     def __init__(self):
#         super().__init__()
#         resnet = torchvision.models.resnet18(pretrained=True)
        
#         self.conv1 = resnet.conv1
#         self.bn1 = resnet.bn1
#         self.relu = resnet.relu
#         self.maxpool = resnet.maxpool
#         self.layer1 = resnet.layer1
#         self.layer2 = resnet.layer2
#         self.layer3 = resnet.layer3
        
#     def forward(self, x):
#         x = self.conv1(x)
#         x = self.bn1(x)
#         x = self.relu(x)
#         x = self.maxpool(x)
#         x = self.layer1(x)
#         x = self.layer2(x)
#         x = self.layer3(x)
#         return x
    
class FeatureExtractor(nn.Module):
    """特征提取器类，基于预训练的ResNet18，使用多层特征融合"""
    
    def __init__(self, num_channels_per_layer: int = 64, random_seed: int = 42):
        """
        初始化特征提取器
        
        Args:
            num_channels_per_layer: 每层采样的通道数
            random_seed: 随机种子，用于确保通道采样的一致性
        """
        super().__init__()
        # 设置随机种子以确保通道采样的一致性
        torch.manual_seed(random_seed)
        np.random.seed(random_seed)
        
        # 加载预训练的ResNet18
        resnet = torchvision.models.resnet18(pretrained=True)
        
        # 提取需要的层
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1  # 56x56
        self.layer2 = resnet.layer2  # 28x28
        self.layer3 = resnet.layer3  # 14x14
        self.layer4 = resnet.layer4  # 7x7
        
        # 获取每层的输出通道数
        layer1_channels = self.layer1[-1].conv2.out_channels  # 获取layer1最后一个block的输出通道数
        layer2_channels = self.layer2[-1].conv2.out_channels
        layer3_channels = self.layer3[-1].conv2.out_channels
        layer4_channels = self.layer4[-1].conv2.out_channels
        
        # 存储每层要采样的通道索引
        self.num_channels = num_channels_per_layer
        self.layer1_indices = torch.randperm(layer1_channels)[:num_channels_per_layer]
        self.layer2_indices = torch.randperm(layer2_channels)[:num_channels_per_layer]
        self.layer3_indices = torch.randperm(layer3_channels)[:num_channels_per_layer]
        self.layer4_indices = torch.randperm(layer4_channels)[:num_channels_per_layer]
    
    def _sample_channels(self, x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        """从特征图中采样指定的通道"""
        return x[:, indices, :, :]
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        前向传播，提取并融合多层特征
        
        Args:
            x: 输入图像张量

        Returns:
            融合后的特征图，尺寸为56x56
        """
        # 基础特征提取
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        
        # 提取各层特征
        feat1 = self.layer1(x)          # 56x56
        feat2 = self.layer2(feat1)      # 28x28
        feat3 = self.layer3(feat2)      # 14x14
        feat4 = self.layer4(feat3)      # 7x7
        
        # 对每层特征进行通道采样
        feat1_sampled = self._sample_channels(feat1, self.layer1_indices)  # 保持56x56
        feat2_sampled = self._sample_channels(feat2, self.layer2_indices)  
        feat3_sampled = self._sample_channels(feat3, self.layer3_indices)
        feat4_sampled = self._sample_channels(feat4, self.layer4_indices)
        
        # 创建上采样层，统一上采样到56x56
        # 获取目标特征图大小
        target_size = feat1.shape[-2:]  # 获取H,W
        up = nn.Upsample(size=target_size, mode='bilinear', align_corners=False)
        
        # 上采样到56x56
        feat2_up = up(feat2_sampled)    # 28x28 -> 56x56
        feat3_up = up(feat3_sampled)    # 14x14 -> 56x56
        feat4_up = up(feat4_sampled)    # 7x7 -> 56x56
        
        # 特征融合 (简单相加)
        fused_features = feat1_sampled + feat2_up + feat3_up + feat4_up
        
        return fused_features
    
class AnomalyDetector:
    """Anomaly detector based on the mean feature map of normal images.

    ``train`` stores the mean (over all training images) of the feature maps
    produced by ``FeatureExtractor``. At inference the anomaly map is the
    per-position L2 distance between an image's feature map and that mean,
    bilinearly upsampled to ``input_size``; the image score is the mean of the map.
    """

    # ImageNet normalisation statistics, shared by the input transform and by the
    # de-normalisation in visualize_prediction.
    _NORM_MEAN = [0.485, 0.456, 0.406]
    _NORM_STD = [0.229, 0.224, 0.225]

    def __init__(self, model_path: Optional[str] = None, input_size: Tuple[int, int] = (224, 224),
                 seed: Optional[int] = 0):
        """
        Args:
            model_path: optional file written by ``save_model``; loaded when given.
            input_size: (H, W) every image is resized to; anomaly maps are produced
                at this size.
            seed: seed for the CPU RNG while the feature extractor is constructed, so
                any random choices made in its constructor (e.g. ``torch.randperm``
                channel sampling) are identical in every detector instance. This
                matters because ``save_model`` persists only the mean feature, not
                the extractor. The caller's RNG state is restored afterwards. Pass
                ``None`` to use the global RNG as-is.
        """
        self.input_size = input_size
        self.transform = transforms.Compose([
            transforms.Resize(self.input_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=self._NORM_MEAN, std=self._NORM_STD),
        ])

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.feature_extractor = self._build_feature_extractor(seed).to(self.device)
        self.feature_extractor.eval()

        # Mean training feature map, kept on the CPU; None until trained or loaded.
        self.mean_feature: Optional[torch.Tensor] = None
        if model_path:
            self.load_model(model_path)

    @staticmethod
    def _build_feature_extractor(seed: Optional[int]) -> nn.Module:
        """Construct ``FeatureExtractor``, optionally under a fixed, isolated CPU RNG seed."""
        if seed is None:
            return FeatureExtractor()
        saved_state = torch.random.get_rng_state()
        try:
            # Seed only the CPU default generator (torch.manual_seed would also
            # reseed CUDA generators) and restore it afterwards.
            torch.default_generator.manual_seed(seed)
            return FeatureExtractor()
        finally:
            torch.random.set_rng_state(saved_state)

    def train(self, data_root: str, batch_size: int = 32) -> Dict[str, Union[str, int]]:
        """Fit the detector: average the feature maps of all training (normal) images.

        Args:
            data_root: dataset root passed to ``MVTecDataset`` with ``is_train=True``.
            batch_size: images per forward pass.

        Returns:
            ``{"status": "success", "num_samples": <number of training images>}``.

        Raises:
            ValueError: if the training split contains no images.
        """
        train_dataset = MVTecDataset(data_root, is_train=True, transform=self.transform)
        if len(train_dataset) == 0:
            raise ValueError(f"No training images found under {data_root!r}")
        # Order is irrelevant for a mean; not shuffling keeps the floating-point
        # summation order, and therefore the result, reproducible.
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)

        print("Extracting training features...")
        # Running float64 sum instead of concatenating every feature map: memory
        # stays at one batch rather than the whole dataset.
        feature_sum = None
        feature_dtype = None
        num_images = 0
        with torch.no_grad():
            for images, _ in tqdm(train_loader):
                features = self.feature_extractor(images.to(self.device))
                batch_sum = features.sum(dim=0, dtype=torch.float64).cpu()
                feature_sum = batch_sum if feature_sum is None else feature_sum + batch_sum
                feature_dtype = features.dtype
                num_images += features.shape[0]

        self.mean_feature = (feature_sum / num_images).to(feature_dtype)

        return {"status": "success", "num_samples": len(train_dataset)}

    def save_model(self, save_path: str):
        """Save the learned mean feature map to ``save_path`` with ``torch.save``.

        Raises:
            ValueError: if the detector has not been trained or loaded.
        """
        if self.mean_feature is None:
            raise ValueError("Model not trained yet!")

        directory = os.path.dirname(save_path)
        # dirname() is "" for a bare file name, and os.makedirs("") raises.
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(self.mean_feature, save_path)

    def load_model(self, model_path: str):
        """Load a mean feature map previously written by ``save_model``."""
        # map_location="cpu": the mean is compared against CPU features in
        # _compute_anomaly_map, and a file holding a CUDA tensor must still load
        # on a CPU-only machine.
        # NOTE(review): torch.load unpickles its input; only load trusted files.
        self.mean_feature = torch.load(model_path, map_location="cpu")

    def predict_single_image(self, image_path: str, threshold: float = 0.90) -> Dict[str, Union[float, np.ndarray]]:
        """Score one image and build a binary prediction mask.

        Args:
            image_path: image file; converted to RGB before the transform.
            threshold: fraction in [0, 1]. Pixels of the smoothed, min-max normalised
                anomaly map above this percentile are marked anomalous, so roughly
                the top ``1 - threshold`` of pixels is always flagged.

        Returns:
            dict with ``"anomaly_score"`` (mean of the raw anomaly map),
            ``"anomaly_map"`` (Gaussian-smoothed map scaled to [0, 1], shape
            ``input_size``), ``"pred_mask"`` (bool array of the same shape) and
            ``"image_tensor"`` (the normalised (1, C, H, W) CPU input tensor).

        Raises:
            ValueError: if the detector has not been trained or loaded.
        """
        import time  # local import: only needed for the timing printout below

        if self.mean_feature is None:
            raise ValueError("Model not trained yet!")

        # Context manager closes the file handle PIL would otherwise keep open.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        image_tensor = self.transform(image).unsqueeze(0)

        with torch.no_grad():
            start = time.perf_counter()
            features = self.feature_extractor(image_tensor.to(self.device))
            if self.device.type == "cuda":
                # CUDA kernels run asynchronously; wait for them so the extraction
                # time is not silently billed to the anomaly-map step.
                torch.cuda.synchronize()
            extracted = time.perf_counter()
            anomaly_map = self._compute_anomaly_map(features.cpu(), self.mean_feature)
            finished = time.perf_counter()

        print(f"Feature extraction time: {extracted - start:.4f}s")
        print(f"Anomaly map computation time: {finished - extracted:.4f}s")

        anomaly_score = float(anomaly_map.mean())

        smoothed = gaussian_filter(anomaly_map[0], sigma=1.0)
        normalized = (smoothed - smoothed.min()) / (smoothed.max() - smoothed.min() + 1e-8)
        threshold_value = np.percentile(normalized, threshold * 100)
        pred_mask = normalized > threshold_value

        return {
            "anomaly_score": anomaly_score,
            "anomaly_map": normalized,
            "pred_mask": pred_mask,
            "image_tensor": image_tensor
        }

    def evaluate(self, data_root: str, batch_size: int = 32) -> Dict[str, float]:
        """Compute the image-level ROC AUC on the test split.

        Each image is scored with the mean of its anomaly map (the same score as
        ``predict_single_image``). Larger scores are treated as more anomalous,
        which assumes the dataset's positive label marks anomalies (TODO confirm
        in MVTecDataset).

        Returns:
            ``{"auc_score": float, "num_samples": int}``.

        Raises:
            ValueError: if the detector is untrained or the test split is empty.
        """
        if self.mean_feature is None:
            raise ValueError("Model not trained yet!")

        test_dataset = MVTecDataset(data_root, is_train=False, transform=self.transform)
        if len(test_dataset) == 0:
            raise ValueError(f"No test images found under {data_root!r}")
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                                 collate_fn=self._custom_collate)

        print("Evaluating model...")
        score_batches = []
        test_labels = []
        with torch.no_grad():
            for images, labels, _ in tqdm(test_loader):
                features = self.feature_extractor(images.to(self.device))
                # Score each batch immediately so only one batch of full-resolution
                # anomaly maps is held in memory at a time.
                anomaly_maps = self._compute_anomaly_map(features.cpu(), self.mean_feature)
                score_batches.append(anomaly_maps.mean(axis=(1, 2)))
                test_labels.extend(labels.tolist())

        anomaly_scores = np.concatenate(score_batches)
        auc_score = roc_auc_score(np.array(test_labels), anomaly_scores)

        return {
            "auc_score": float(auc_score),
            "num_samples": len(test_dataset)
        }

    def visualize_prediction(self, result: Dict[str, Union[float, np.ndarray]],
                             save_path: Optional[str] = None,
                             show_gt_mask: bool = True):
        """Plot the input image, the anomaly heat map and the masked overlay.

        Args:
            result: dict returned by ``predict_single_image``.
            save_path: if given, the figure is saved there and closed; otherwise shown.
            show_gt_mask: currently unused; kept for interface compatibility.
        """
        image_tensor = result["image_tensor"][0]  # (C, H, W), still normalised
        anomaly_map = result["anomaly_map"]       # already scaled to [0, 1]
        pred_mask = result["pred_mask"]           # bool mask of predicted anomalies

        # Undo the ImageNet normalisation for display.
        mean = torch.tensor(self._NORM_MEAN).view(3, 1, 1)
        std = torch.tensor(self._NORM_STD).view(3, 1, 1)
        image_hwc = torch.clamp(image_tensor * std + mean, 0, 1).permute(1, 2, 0).numpy()

        num_subplots = 3
        fig, (ax_image, ax_map, ax_overlay) = plt.subplots(1, num_subplots, figsize=(5 * num_subplots, 5))

        ax_image.imshow(image_hwc)
        ax_image.set_title('Original Image')
        ax_image.axis('off')

        heat = ax_map.imshow(anomaly_map, cmap='hot')
        fig.colorbar(heat, ax=ax_map)
        ax_map.set_title(f'Anomaly Map\nScore: {result["anomaly_score"]:.4f}')
        ax_map.axis('off')

        ax_overlay.imshow(image_hwc)
        # Masked array: pixels outside the predicted mask are transparent.
        overlay = np.ma.masked_where(~pred_mask, anomaly_map)
        ax_overlay.imshow(overlay, cmap='hot', alpha=0.6)
        ax_overlay.set_title('Overlay')
        ax_overlay.axis('off')

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path)
            plt.close(fig)
        else:
            plt.show()

    def _compute_anomaly_map(self, features: torch.Tensor, mean_feature: torch.Tensor) -> np.ndarray:
        """Per-position L2 distance to the mean feature, upsampled to ``input_size``.

        Args:
            features: (B, C, H, W) feature maps.
            mean_feature: (C, H, W) mean feature map.

        Returns:
            float array of shape (B, *input_size).
        """
        diff = features - mean_feature.to(features.device).unsqueeze(0)
        distances = torch.norm(diff, dim=1)  # [B, H, W]

        distances_upsampled = torch.nn.functional.interpolate(
            distances.unsqueeze(1), size=self.input_size, mode='bilinear', align_corners=False
        ).squeeze(1)  # [B, *input_size]

        return distances_upsampled.cpu().numpy()

    @staticmethod
    def _custom_collate(batch):
        """Collate (image, label, mask_path) samples into (images, labels, mask_paths).

        Images are stacked into one tensor, labels become a tensor, and mask paths
        stay a plain list (they are strings, which default collation would not keep
        as-is).
        """
        images = torch.stack([item[0] for item in batch])
        labels = torch.tensor([item[1] for item in batch])
        mask_paths = [item[2] for item in batch]
        return images, labels, mask_paths

In [None]:
# Run configuration (absolute local paths; adjust for your machine).
# NOTE(review): consider moving these into a config cell / environment variable.
MVTEC_BOTTLE_ROOT = r"D:\Projects\MVTec-AD\bottle"
MODEL_PATH = r"D:\Projects\ofa-nas-fromework-profile\ad-model.pth"
INPUT_SIZE = (448, 448)

# Build a detector, fit it on the normal training images, and persist the result.
detector = AnomalyDetector(input_size=INPUT_SIZE)
detector.train(MVTEC_BOTTLE_ROOT)
detector.save_model(MODEL_PATH)

# Re-create the detector from the saved file to exercise the load path.
detector = AnomalyDetector(MODEL_PATH, input_size=INPUT_SIZE)


In [None]:
# Score a single defective test image and visualise the prediction.
TEST_IMAGE_PATH = r"D:\Projects\MVTec-AD\bottle\test\broken_small\010.png"
result = detector.predict_single_image(TEST_IMAGE_PATH)
detector.visualize_prediction(result)

# Full test-set evaluation is disabled by default; set to True to compute the
# image-level AUC-ROC.
RUN_EVALUATION = False
if RUN_EVALUATION:
    metrics = detector.evaluate(r"D:\Projects\MVTec-AD\bottle")
    print(f"AUC-ROC Score: {metrics['auc_score']:.4f}")