# Zero-Shot Anomaly Detection using DINOv3

## Imports and Setup

In [1]:
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
from pathlib import Path
from typing import List, Tuple, Optional, Dict
import faiss
from dataclasses import dataclass
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
import matplotlib
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


## ImagePreprocessor

In [2]:
class ImagePreprocessor:
    def __init__(self, input_size: int = 224, batch_size: int = 32):
        self.input_size = input_size
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.Resize(int(input_size * 1.14)),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    def load_image(self, image_path: str) -> Image.Image:
        if isinstance(image_path, str):
            return Image.open(image_path).convert('RGB')
        return image_path
    def preprocess_single(self, image: Image.Image) -> torch.Tensor:
        return self.transform(image).unsqueeze(0)
    def preprocess_batch(self, images: List[Image.Image]) -> torch.Tensor:
        tensors = [self.transform(img) for img in images]
        return torch.stack(tensors)
    def load_images_from_directory(self, directory: str, extensions: Tuple[str, ...] = ('.png', '.jpg', '.jpeg')) -> List[str]:
        directory_path = Path(directory)
        image_paths = []
        for ext in extensions:
            image_paths.extend(directory_path.glob(f"*{ext}"))
        return sorted([str(p) for p in image_paths])
    def create_batches(self, image_paths: List[str]) -> List[Tuple[torch.Tensor, List[str]]]:
        batches = []
        for i in range(0, len(image_paths), self.batch_size):
            batch_paths = image_paths[i:i + self.batch_size]
            batch_images = [self.load_image(path) for path in batch_paths]
            batch_tensor = self.preprocess_batch(batch_images)
            batches.append((batch_tensor, batch_paths))
        return batches
print("✓ ImagePreprocessor class defined")

✓ ImagePreprocessor class defined


## FeatureExtractor

In [3]:
class FeatureExtractor:
    def __init__(self, model, device: torch.device, l2_normalize: bool = True):
        self.model = model.to(device)
        self.model.eval()
        self.device = device
        self.l2_normalize = l2_normalize
        self.patch_size = 16
    @torch.no_grad()
    def extract_features(self, images: torch.Tensor, return_cls: bool = False) -> torch.Tensor:
        images = images.to(self.device)
        if hasattr(self.model, 'forward_features'):
            features = self.model.forward_features(images)
            if isinstance(features, dict):
                if 'x_norm_patchtokens' in features:
                    patch_features = features['x_norm_patchtokens']
                    cls_features = features.get('x_norm_clstoken', None)
                elif 'x_prenorm' in features:
                    all_tokens = features['x_prenorm']
                    cls_features = all_tokens[:, 0, :]
                    patch_features = all_tokens[:, 1:, :]
                else:
                    raise ValueError(f"Unexpected dict keys: {features.keys()}")
            else:
                if len(features.shape) == 3:
                    cls_features = features[:, 0, :]
                    patch_features = features[:, 1:, :]
                elif len(features.shape) == 2:
                    features = self.model.get_intermediate_layers(images, n=1, return_class_token=True)
                    if isinstance(features, (list, tuple)):
                        features = features[0]
                    cls_features = features[0] if isinstance(features, (list, tuple)) else features[:, 0, :]
                    patch_features = features[1] if isinstance(features, (list, tuple)) else features[:, 1:, :]
                else:
                    raise ValueError(f"Unexpected feature shape: {features.shape}")
        else:
            features = self.model(images)
            if isinstance(features, dict):
                if 'x_norm_patchtokens' in features:
                    patch_features = features['x_norm_patchtokens']
                    cls_features = features.get('x_norm_clstoken', None)
                else:
                    raise ValueError(f"Cannot extract patches from dict: {features.keys()}")
            elif len(features.shape) == 2:
                patch_features = self.model.get_intermediate_layers(images, n=1, return_class_token=False)[0]
                cls_features = features
            else:
                cls_features = features[:, 0, :]
                patch_features = features[:, 1:, :]
        if self.l2_normalize:
            patch_features = F.normalize(patch_features, p=2, dim=-1)
            if return_cls and cls_features is not None:
                cls_features = F.normalize(cls_features, p=2, dim=-1)
        if return_cls:
            return patch_features, cls_features
        return patch_features
    def get_feature_dim(self) -> int:
        if hasattr(self.model, 'embed_dim'):
            return self.model.embed_dim
        elif hasattr(self.model, 'num_features'):
            return self.model.num_features
        elif hasattr(self.model, 'backbone'):
            if hasattr(self.model.backbone, 'embed_dim'):
                return self.model.backbone.embed_dim
        return 384
    def get_patch_grid_size(self, input_size: int = 224) -> Tuple[int, int]:
        grid_h = grid_w = input_size // self.patch_size
        return (grid_h, grid_w)
    def extract_from_multiple_batches(self, batches: List[Tuple[torch.Tensor, List[str]]]) -> Dict[str, np.ndarray]:
        feature_dict = {}
        for batch_tensor, batch_paths in tqdm(batches):
            patch_features = self.extract_features(batch_tensor)
            patch_features_np = patch_features.cpu().numpy()
            for i, path in enumerate(batch_paths):
                feature_dict[path] = patch_features_np[i]
        return feature_dict

## NormalFeatureBank

In [4]:
class NormalFeatureBank:
    def __init__(self, feature_dim: int, use_gpu: bool = True, index_type: str = 'flat'):
        self.feature_dim = feature_dim
        self.index_type = index_type
        self.use_gpu = use_gpu and torch.cuda.is_available() and hasattr(faiss, 'StandardGpuResources')
        self.index = self._create_index()
        self.mean = None
        self.cov_inv = None
        self.features_array = None
    def _create_index(self) -> faiss.Index:
        if self.index_type == 'flat':
            index = faiss.IndexFlatL2(self.feature_dim)
        elif self.index_type == 'ivf':
            quantizer = faiss.IndexFlatL2(self.feature_dim)
            index = faiss.IndexIVFFlat(quantizer, self.feature_dim, 100)
        else:
            raise ValueError(f"Unknown index type: {self.index_type}")
        if self.use_gpu:
            try:
                res = faiss.StandardGpuResources()
                index = faiss.index_cpu_to_gpu(res, 0, index)
            except Exception:
                self.use_gpu = False
        return index
    def build_from_features(self, feature_dict: Dict[str, np.ndarray], compute_statistics: bool = True):
        all_features = []
        for features in feature_dict.values():
            all_features.append(features)
        self.features_array = np.vstack(all_features).astype('float32')
        if self.index_type == 'ivf' and not self.index.is_trained:
            self.index.train(self.features_array)
        self.index.add(self.features_array)
        if compute_statistics:
            self._compute_statistics()
    def _compute_statistics(self):
        self.mean = np.mean(self.features_array, axis=0)
        centered = self.features_array - self.mean
        cov = np.cov(centered.T)
        reg = 1e-5
        cov = cov + reg * np.eye(self.feature_dim)
        try:
            self.cov_inv = np.linalg.inv(cov)
        except np.linalg.LinAlgError:
            self.cov_inv = np.linalg.pinv(cov)
    def search_knn(self, query_features: np.ndarray, k: int = 1) -> Tuple[np.ndarray, np.ndarray]:
        query_features = query_features.astype('float32')
        distances, indices = self.index.search(query_features, k)
        return distances, indices
    def compute_mahalanobis_distance(self, query_features: np.ndarray) -> np.ndarray:
        if self.mean is None or self.cov_inv is None:
            raise ValueError("Statistics not computed.")
        centered = query_features - self.mean
        mahal = np.sum(centered @ self.cov_inv * centered, axis=1)
        return np.sqrt(mahal)
    def save(self, path: str):
        save_dict = {
            'features_array': self.features_array,
            'mean': self.mean,
            'cov_inv': self.cov_inv,
            'feature_dim': self.feature_dim,
            'index_type': self.index_type
        }
        index_path = path + '.index'
        if self.use_gpu:
            try:
                cpu_index = faiss.index_gpu_to_cpu(self.index)
                faiss.write_index(cpu_index, index_path)
            except:
                faiss.write_index(self.index, index_path)
        else:
            faiss.write_index(self.index, index_path)
        np.savez(path + '.npz', **save_dict)
    def load(self, path: str):
        data = np.load(path + '.npz', allow_pickle=True)
        self.features_array = data['features_array']
        self.mean = data['mean']
        self.cov_inv = data['cov_inv']
        self.feature_dim = int(data['feature_dim'])
        self.index_type = str(data['index_type'])
        index_path = path + '.index'
        self.index = faiss.read_index(index_path)
        if self.use_gpu:
            try:
                res = faiss.StandardGpuResources()
                self.index = faiss.index_cpu_to_gpu(res, 0, self.index)
            except Exception:
                self.use_gpu = False

## AnomalyScorer

In [5]:
@dataclass
class AnomalyResult:
    image_score: float
    patch_scores: np.ndarray
    anomaly_map: np.ndarray
    is_anomaly: bool
    method: str
class AnomalyScorer:
    def __init__(self, feature_bank: NormalFeatureBank, patch_grid_size: Tuple[int, int], method: str = 'knn', k: int = 1, threshold: Optional[float] = None):
        self.feature_bank = feature_bank
        self.patch_grid_size = patch_grid_size
        self.method = method
        self.k = k
        self.threshold = threshold
    def compute_patch_scores(self, patch_features: np.ndarray) -> np.ndarray:
        if self.method == 'knn':
            distances, _ = self.feature_bank.search_knn(patch_features, k=self.k)
            scores = np.mean(distances, axis=1)
        elif self.method == 'mahalanobis':
            scores = self.feature_bank.compute_mahalanobis_distance(patch_features)
        else:
            raise ValueError(f"Unknown method: {self.method}")
        return scores
    def create_anomaly_map(self, patch_scores: np.ndarray, target_size: Tuple[int, int] = (224, 224)) -> np.ndarray:
        H, W = self.patch_grid_size
        score_map = patch_scores.reshape(H, W)
        anomaly_map = cv2.resize(score_map, target_size, interpolation=cv2.INTER_CUBIC)
        return anomaly_map
    def compute_image_score(self, patch_scores: np.ndarray, aggregation: str = 'max') -> float:
        if aggregation == 'max':
            return float(np.max(patch_scores))
        elif aggregation == 'mean':
            return float(np.mean(patch_scores))
        elif aggregation == 'percentile_95':
            return float(np.percentile(patch_scores, 95))
        else:
            raise ValueError(f"Unknown aggregation: {aggregation}")
    def score_image(self, patch_features: np.ndarray, target_size: Tuple[int, int] = (224, 224), aggregation: str = 'max') -> AnomalyResult:
        patch_scores = self.compute_patch_scores(patch_features)
        image_score = self.compute_image_score(patch_scores, aggregation)
        anomaly_map = self.create_anomaly_map(patch_scores, target_size)
        is_anomaly = False
        if self.threshold is not None:
            is_anomaly = image_score > self.threshold
        return AnomalyResult(image_score=image_score, patch_scores=patch_scores, anomaly_map=anomaly_map, is_anomaly=is_anomaly, method=self.method)
    def set_threshold_from_normal_scores(self, normal_image_scores: List[float], percentile: float = 99.0):
        self.threshold = np.percentile(normal_image_scores, percentile)
    def visualize_result(self, image: Image.Image, result: AnomalyResult, figsize: Tuple[int, int] = (15, 5)):
        fig, axes = plt.subplots(1, 3, figsize=figsize)
        axes[0].imshow(image)
        axes[1].imshow(result.anomaly_map, cmap='jet')
        plt.colorbar(axes[1].images[0], ax=axes[1])
        axes[2].imshow(image)
        axes[2].imshow(result.anomaly_map, cmap='jet', alpha=0.5)
        plt.tight_layout()
        plt.show()

## ZeroShotAnomalyDetector

In [6]:
class ZeroShotAnomalyDetector:
    def __init__(self, model, input_size: int = 224, batch_size: int = 32, device: torch.device = None, l2_normalize: bool = True, use_gpu_faiss: bool = True, index_type: str = 'flat', scoring_method: str = 'knn', k: int = 1):
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.preprocessor = ImagePreprocessor(input_size=input_size, batch_size=batch_size)
        self.feature_extractor = FeatureExtractor(model, device=self.device, l2_normalize=l2_normalize)
        self.input_size = input_size
        self.scoring_method = scoring_method
        self.k = k
        self.use_gpu_faiss = use_gpu_faiss
        self.index_type = index_type
        self.feature_bank = None
        self.scorer = None
    def fit(self, normal_images_dir: str, compute_mahalanobis: bool = True):
        image_paths = self.preprocessor.load_images_from_directory(normal_images_dir)
        batches = self.preprocessor.create_batches(image_paths)
        feature_dict = self.feature_extractor.extract_from_multiple_batches(batches)
        feature_dim = self.feature_extractor.get_feature_dim()
        self.feature_bank = NormalFeatureBank(feature_dim=feature_dim, use_gpu=self.use_gpu_faiss, index_type=self.index_type)
        compute_stats = compute_mahalanobis or (self.scoring_method == 'mahalanobis')
        self.feature_bank.build_from_features(feature_dict, compute_statistics=compute_stats)
        patch_grid_size = self.feature_extractor.get_patch_grid_size(self.input_size)
        self.scorer = AnomalyScorer(feature_bank=self.feature_bank, patch_grid_size=patch_grid_size, method=self.scoring_method, k=self.k)
        return self
    def set_threshold(self, validation_images_dir: str, percentile: float = 99.0):
        validation_scores = []
        image_paths = self.preprocessor.load_images_from_directory(validation_images_dir)
        for img_path in image_paths:
            image = self.preprocessor.load_image(img_path)
            result = self.predict(image)
            validation_scores.append(result.image_score)
        self.scorer.set_threshold_from_normal_scores(validation_scores, percentile)
        return self
    def predict(self, image: Image.Image) -> AnomalyResult:
        image_tensor = self.preprocessor.preprocess_single(image)
        patch_features = self.feature_extractor.extract_features(image_tensor)
        patch_features = patch_features[0].cpu().numpy()
        result = self.scorer.score_image(patch_features, target_size=(self.input_size, self.input_size))
        return result
    def predict_from_path(self, image_path: str) -> AnomalyResult:
        image = self.preprocessor.load_image(image_path)
        return self.predict(image)
    def predict_batch(self, images: List[Image.Image]) -> List[AnomalyResult]:
        return [self.predict(img) for img in images]
    def evaluate(self, test_images_dir: str, show_samples: int = 5) -> Dict:
        image_paths = self.preprocessor.load_images_from_directory(test_images_dir)
        results = []
        for img_path in image_paths:
            image = self.preprocessor.load_image(img_path)
            result = self.predict(image)
            results.append((img_path, image, result))
        scores = [r[2].image_score for r in results]
        anomalies = [r[2].is_anomaly for r in results]
        metrics = {
            'num_images': len(results),
            'num_anomalies': sum(anomalies),
            'anomaly_rate': sum(anomalies) / len(anomalies) if anomalies else 0,
            'score_mean': np.mean(scores),
            'score_std': np.std(scores),
            'score_min': np.min(scores),
            'score_max': np.max(scores)
        }
        return metrics
    def save(self, path: str):
        self.feature_bank.save(path + '_bank')
        config = {
            'input_size': self.input_size,
            'scoring_method': self.scoring_method,
            'k': self.k,
            'threshold': self.scorer.threshold if self.scorer else None,
            'patch_grid_size': self.scorer.patch_grid_size if self.scorer else None
        }
        np.savez(path + '_config.npz', **config)
    def load(self, path: str):
        feature_dim = self.feature_extractor.get_feature_dim()
        self.feature_bank = NormalFeatureBank(feature_dim=feature_dim, use_gpu=self.use_gpu_faiss, index_type=self.index_type)
        self.feature_bank.load(path + '_bank')
        config = np.load(path + '_config.npz', allow_pickle=True)
        self.input_size = int(config['input_size'])
        self.scoring_method = str(config['scoring_method'])
        self.k = int(config['k'])
        threshold = config['threshold'].item()
        patch_grid_size = tuple(config['patch_grid_size'])
        self.scorer = AnomalyScorer(feature_bank=self.feature_bank, patch_grid_size=patch_grid_size, method=self.scoring_method, k=self.k, threshold=threshold)

## Load DINOv3 Model

In [7]:
REPO_DIR = r"C:\Users\berko\OneDrive - Danmarks Tekniske Universitet\DTU\Deep Learning\dinov3"
WEIGHTS_PATH = r"C:\Users\berko\OneDrive - Danmarks Tekniske Universitet\DTU\Deep Learning\dinov3_vitl16_dinotxt_vision_head_and_text_encoder-a442d8f5.pth"
BACKBONE_WEIGHTS_PATH = r"C:\Users\berko\OneDrive - Danmarks Tekniske Universitet\DTU\Deep Learning\dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth"

print("Loading DINOv3 ViT-L/16 with text encoder...")

dinov3_model, tokenizer = torch.hub.load(
    REPO_DIR,
    model='dinov3_vitl16_dinotxt_tet1280d20h24l',
    source='local',
    weights=WEIGHTS_PATH,
    backbone_weights=BACKBONE_WEIGHTS_PATH
)

print("✓ Model loaded successfully")

Loading DINOv3 ViT-L/16 with text encoder...
✓ Model loaded successfully


## Zero-Shot Classification with Text Prompts

In [8]:
SCORE_THRESHOLD = 0.20
BELOW_THRESHOLD_PENALTY = -5

anomaly_descriptions = {
    'NORMAL': [
        "Transistor with three long complete legs straight vertical inserted in holes",
        "Transistor properly aligned upright body centered legs perpendicular",
        "Perfect transistor intact no damage no bent legs clean",
        "Transistor all three legs fully inserted flush contact with holes"
    ],
    'BENT_LEADS': [
        "Transistor with legs bent twisted deformed at sharp kink angles",
        "Transistor legs crooked curved irregular shapes not straight not parallel",
        "Transistor severely deformed bent kinked legs uneven geometry",
        "Transistor legs twisted bent outward different angles not vertical"
    ],
    'DAMAGED_CASE': [
        "Transistor plastic body cracked fractured split shattered broken case",
        "Transistor case missing chunks pieces jagged rough edges destroyed",
        "Transistor body damaged destroyed deteriorated cracks throughout",
        "Transistor plastic shattered broken chunks missing holes in case"
    ],
    'CUT_LEADS': [
        "Transistor legs cut close tobody trimmed shortened not full length",
        "Transistor legs fractured broken off severed at where it is close to the balck body missing length",
        "Transistor has a mismatching leg pattern one or more legs cut short",
        "Transistor legs do not reach holes too short cannot insert contact"
    ],
    'MISPLACED': [
        "Transistor body rotated tilted at severe angle sideways position",
        "Transistor leaning over lying horizontal not standing vertical upright",
        "Transistor oriented at extreme angle rotated away not perpendicular",
        "Transistor is placed horizontally"
    ]
}

all_descriptions = []
class_mapping = []
class_names = ['NORMAL', 'BENT_LEADS', 'DAMAGED_CASE', 'CUT_LEADS', 'MISPLACED']

for class_idx, class_name in enumerate(class_names):
    for desc in anomaly_descriptions[class_name]:
        all_descriptions.append(desc)
        class_mapping.append(class_idx)

dinov3_model = dinov3_model.to(device)
text_tokens = tokenizer.tokenize(all_descriptions).to(device)

with torch.no_grad():
    text_features = dinov3_model.encode_text(text_tokens)
    text_features = torch.nn.functional.normalize(text_features, p=2, dim=-1)

test_image_path = r"C:\Users\berko\OneDrive - Danmarks Tekniske Universitet\DTU\Deep Learning\transistor\test\misplaced\005.png"
image = Image.open(test_image_path).convert('RGB')

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

image_tensor = transform(image).unsqueeze(0).to(device)

with torch.no_grad():
    image_features = dinov3_model.encode_image(image_tensor)
    image_features = torch.nn.functional.normalize(image_features, p=2, dim=-1)

similarities = (image_features @ text_features.T).squeeze().cpu().numpy()
similarities_squared = np.square(similarities)

class_scores = {}
for class_idx in range(len(class_names)):
    class_desc_indices = [i for i, c in enumerate(class_mapping) if c == class_idx]
    class_desc_scores_original = [similarities[i] for i in class_desc_indices]
    class_desc_scores_squared = [similarities_squared[i] for i in class_desc_indices]
    
    penalized_scores = []
    for orig_score, sq_score in zip(class_desc_scores_original, class_desc_scores_squared):
        if orig_score < SCORE_THRESHOLD:
            penalized_score = sq_score * BELOW_THRESHOLD_PENALTY
        elif orig_score > SCORE_THRESHOLD:
            penalized_score = sq_score * (1 + (orig_score-SCORE_THRESHOLD) * 5)
        else:
            penalized_score = sq_score
        penalized_scores.append(penalized_score)
    
    if penalized_scores:
        weights = np.array(penalized_scores) / np.sum(penalized_scores)
        weighted_score = np.sum(np.array(penalized_scores) * weights)
        class_scores[class_names[class_idx]] = weighted_score
    else:
        class_scores[class_names[class_idx]] = 0.0

pred_class = max(class_scores.items(), key=lambda x: x[1])

## Evaluation on Test Set

In [9]:
import os
import time
from collections import defaultdict

BASE_TEST_PATH = r"C:\Users\berko\OneDrive - Danmarks Tekniske Universitet\DTU\Deep Learning\transistor\test"

folder_to_class = {
    'good': 'NORMAL',
    'bent_lead': 'BENT_LEAD',
    'damaged_case': 'DAMAGED_CASE',
    'cut_lead': 'CUT_LEAD',
    'misplaced': 'MISPLACED'
}

class_names = ['NORMAL', 'BENT_LEAD', 'DAMAGED_CASE', 'CUT_LEAD', 'MISPLACED']

results_by_class = defaultdict(lambda: {'correct': 0, 'total': 0})
all_correct = 0
all_total = 0
general_correct = 0
general_total = 0
penalized_correct = 0
penalized_total = 0

start_time = time.time()

for folder in os.listdir(BASE_TEST_PATH):
    folder_path = os.path.join(BASE_TEST_PATH, folder)
    if not os.path.isdir(folder_path):
        continue
    
    class_label = folder_to_class.get(folder)
    if class_label is None:
        print(f"Skipping folder '{folder}' (not mapped to class)")
        continue
    
    for fname in os.listdir(folder_path):
        if not fname.lower().endswith('.png'):
            continue
        
        img_path = os.path.join(folder_path, fname)
        image = Image.open(img_path).convert('RGB')
        
        t0 = time.time()
        
        with torch.no_grad():
            image_tensor = transform(image).unsqueeze(0).to(device)
            image_features = dinov3_model.encode_image(image_tensor)
            image_features = torch.nn.functional.normalize(image_features, p=2, dim=-1)
            similarities = (image_features @ text_features.T).squeeze().cpu().numpy()
        
        similarities_squared = np.square(similarities)
        
        class_scores = {}
        for class_idx in range(len(class_names)):
            class_desc_indices = [i for i, c in enumerate(class_mapping) if c == class_idx]
            class_desc_scores_original = [similarities[i] for i in class_desc_indices]
            class_desc_scores_squared = [similarities_squared[i] for i in class_desc_indices]
            
            penalized_scores = []
            for orig_score, sq_score in zip(class_desc_scores_original, class_desc_scores_squared):
                if orig_score < SCORE_THRESHOLD:
                    penalized_score = sq_score * BELOW_THRESHOLD_PENALTY
                elif orig_score > SCORE_THRESHOLD:
                    penalized_score = sq_score * (1 + (orig_score-SCORE_THRESHOLD) * 5)
                else:
                    penalized_score = sq_score
                penalized_scores.append(penalized_score)
            
            if penalized_scores:
                weights = np.array(penalized_scores) / np.sum(penalized_scores)
                weighted_score = np.sum(np.array(penalized_scores) * weights)
                class_scores[class_names[class_idx]] = weighted_score
            else:
                class_scores[class_names[class_idx]] = 0.0
        
        pred_class = max(class_scores.items(), key=lambda x: x[1])[0]
        
        results_by_class[class_label]['total'] += 1
        all_total += 1
        
        if pred_class == class_label:
            results_by_class[class_label]['correct'] += 1
            all_correct += 1
        
        is_true_normal = (class_label == 'NORMAL')
        is_pred_normal = (pred_class == 'NORMAL')
        is_true_anomaly = not is_true_normal
        is_pred_anomaly = not is_pred_normal
        
        if (is_true_normal and is_pred_normal) or (is_true_anomaly and is_pred_anomaly):
            general_correct += 1
        general_total += 1
        
        penalized_total += 1
        if pred_class == class_label:
            penalized_correct += 1
        elif is_true_anomaly and is_pred_normal:
            pass
        elif is_true_normal and is_pred_anomaly:
            penalized_correct += 1
        elif is_true_anomaly and is_pred_anomaly:
            penalized_correct += 1
        
        elapsed = time.time() - t0
        print(f"Image: {fname} | True: {class_label} | Predicted: {pred_class} | {'CORRECT' if pred_class == class_label else 'WRONG'} | Time: {elapsed:.3f}s")

end_time = time.time()
avg_time = (end_time - start_time) / all_total if all_total > 0 else 0.0

print("\nClassification Accuracy by Class:")
for class_label in class_names:
    correct = results_by_class[class_label]['correct']
    total = results_by_class[class_label]['total']
    acc = (correct / total * 100) if total > 0 else 0.0
    print(f"{class_label:15s}: {correct}/{total} ({acc:.2f}%)")

overall_acc = (all_correct / all_total * 100) if all_total > 0 else 0.0
general_acc = (general_correct / general_total * 100) if general_total > 0 else 0.0
penalized_acc = (penalized_correct / penalized_total * 100) if penalized_total > 0 else 0.0

print("\nSummary Table:")
print("---------------------------------------------------------------")
print(f"Overall classification accuracy (specific): {all_correct}/{all_total} ({overall_acc:.2f}%)")
print(f"General anomaly/normal accuracy:           {general_correct}/{general_total} ({general_acc:.2f}%)")
print(f"Penalized accuracy (false positives):      {penalized_correct}/{penalized_total} ({penalized_acc:.2f}%)")
print(f"Average time per image:                    {avg_time:.3f} seconds")
print("---------------------------------------------------------------")

Image: 000.png | True: BENT_LEAD | Predicted: NORMAL | WRONG | Time: 0.935s
Image: 001.png | True: BENT_LEAD | Predicted: NORMAL | WRONG | Time: 0.946s
Image: 002.png | True: BENT_LEAD | Predicted: NORMAL | WRONG | Time: 0.930s
Image: 003.png | True: BENT_LEAD | Predicted: NORMAL | WRONG | Time: 1.027s
Image: 004.png | True: BENT_LEAD | Predicted: NORMAL | WRONG | Time: 0.930s
Image: 005.png | True: BENT_LEAD | Predicted: NORMAL | WRONG | Time: 0.935s
Image: 006.png | True: BENT_LEAD | Predicted: NORMAL | WRONG | Time: 0.926s
Image: 007.png | True: BENT_LEAD | Predicted: CUT_LEAD | WRONG | Time: 0.933s
Image: 008.png | True: BENT_LEAD | Predicted: NORMAL | WRONG | Time: 0.921s
Image: 009.png | True: BENT_LEAD | Predicted: NORMAL | WRONG | Time: 0.945s
Image: 000.png | True: CUT_LEAD | Predicted: NORMAL | WRONG | Time: 0.949s
Image: 001.png | True: CUT_LEAD | Predicted: NORMAL | WRONG | Time: 0.937s
Image: 002.png | True: CUT_LEAD | Predicted: NORMAL | WRONG | Time: 0.926s
Image: 003.pn