# In [3]:
"""
MVTec AD Anomaly Detection con DINOv2
=====================================

Dos métodos optimizados para MVTec AD (imágenes alineadas):
1. Dense Matching (Posicional) - Comparación 1:1 entre patches
2. Memory Bank + k-NN (PatchCore-style) - Estado del arte

Ambos métodos permiten configurar la capa de extracción de features.
"""

import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModel
from PIL import Image
import scipy.ndimage as ndimage
from typing import List, Optional, Tuple, Union
import cv2
import warnings


# =============================================================================
# FUNCIONES DE NORMALIZACIÓN Y RESIZE
# =============================================================================

def normalize_anomaly_map(
    anomaly_map: np.ndarray,
    method: str = 'minmax',
    clip_percentile: Optional[Tuple[float, float]] = None,
    robust_percentile: Tuple[float, float] = (2, 98)
) -> np.ndarray:
    """
    Normalize an anomaly map to the [0, 1] range.

    The input is never modified; a float32 copy is normalized and returned.

    Args:
        anomaly_map: Anomaly map with continuous values. Any array-like
            accepted by ``np.array`` works (ndarray, nested lists, ...).
        method: Normalization method:
            - 'minmax': standard min-max scaling.
            - 'robust': scaling between two percentiles, then clipping to
              [0, 1], which limits the influence of outliers.
        clip_percentile: Optional ``(min_percentile, max_percentile)`` used to
            clip extreme values before normalizing, e.g. ``(1, 99)``.
        robust_percentile: Percentiles used by the 'robust' method.

    Returns:
        A float32 array with the same shape as the input and values in
        [0, 1]. A uniform map yields all zeros (and a warning).

    Raises:
        ValueError: If ``method`` is unknown or the map is empty.

    Example:
        >>> amap = normalize_anomaly_map(raw_scores, method='robust')
        >>> amap = normalize_anomaly_map(raw_scores, clip_percentile=(1, 99))
    """
    # Validate first so a typo in `method` fails before any numeric work.
    if method not in ('minmax', 'robust'):
        raise ValueError(f"Método de normalización no válido: {method}. Usar 'minmax' o 'robust'")

    # np.array always copies, so the caller's data is left untouched.
    amap = np.array(anomaly_map, dtype=np.float32)
    if amap.size == 0:
        raise ValueError("El mapa de anomalía está vacío")

    # Step 1: optionally clip extreme values.
    if clip_percentile is not None:
        p_min, p_max = clip_percentile
        amap = np.clip(amap, np.percentile(amap, p_min), np.percentile(amap, p_max))

    # Step 2a: robust scaling; falls through to min-max when the two
    # percentiles coincide and no scale can be derived from them.
    if method == 'robust':
        p_low, p_high = robust_percentile
        v_low = np.percentile(amap, p_low)
        v_high = np.percentile(amap, p_high)
        if v_high > v_low:
            scaled = np.clip((amap - v_low) / (v_high - v_low), 0, 1)
            return scaled.astype(np.float32, copy=False)
        warnings.warn("Percentiles robustos iguales, usando normalización min-max")

    # Step 2b: min-max scaling.
    amap_min = amap.min()
    amap_max = amap.max()
    if amap_max > amap_min:
        scaled = (amap - amap_min) / (amap_max - amap_min)
        return scaled.astype(np.float32, copy=False)

    # Uniform map: there is no contrast to normalize.
    warnings.warn("Mapa de anomalía tiene valores uniformes (min == max)")
    return np.zeros_like(amap)


def resize_anomaly_map(
    anomaly_map: np.ndarray,
    target_size: Tuple[int, int],
    interpolation: str = 'bilinear'
) -> np.ndarray:
    """
    Resize an anomaly map to a target size with OpenCV.

    Args:
        anomaly_map: Anomaly map [H, W].
        target_size: Target size as (height, width).
        interpolation: Interpolation method:
            - 'bilinear': bilinear interpolation (default)
            - 'nearest': nearest neighbour
            - 'cubic': bicubic interpolation

    Returns:
        The map converted to float32 and resized to ``target_size``.

    Raises:
        ValueError: If ``interpolation`` is not one of the supported names.
    """
    cv2_flags = {
        'bilinear': cv2.INTER_LINEAR,
        'nearest': cv2.INTER_NEAREST,
        'cubic': cv2.INTER_CUBIC
    }

    flag = cv2_flags.get(interpolation)
    if flag is None:
        raise ValueError(f"Interpolaci√≥n no v√°lida: {interpolation}. Usar: {list(cv2_flags.keys())}")

    # cv2.resize takes its size argument as (width, height), i.e. the
    # reverse of the (height, width) convention used by this function.
    height, width = target_size[0], target_size[1]
    source = anomaly_map.astype(np.float32)

    return cv2.resize(source, (width, height), interpolation=flag)


# =============================================================================
# CONFIGURACIÓN DEL MODELO
# =============================================================================

class DINOv2FeatureExtractor:
    """
    Patch-feature extractor built on a DINOv2 model with a configurable layer.

    Args:
        model_path: Path or identifier passed to ``from_pretrained``.
        layer_idx: Layer to read features from. ``-1`` uses
            ``last_hidden_state``; any other value indexes the model's
            ``hidden_states`` tuple.
        device: Inference device ('cuda' or 'cpu'). Defaults to CUDA when
            available, otherwise CPU.

    NOTE(review): the first token of every sequence is dropped as the CLS
    token. Checkpoints that prepend extra tokens (e.g. register tokens) would
    need a larger offset — confirm for the checkpoints in use.
    """

    # Fallback when the model config does not expose an integer patch size.
    _DEFAULT_PATCH_SIZE = 14

    def __init__(self, model_path: str, layer_idx: int = -1, device: Optional[str] = None):
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.processor = AutoImageProcessor.from_pretrained(model_path)
        self.model = AutoModel.from_pretrained(model_path).to(self.device)
        self.model.eval()
        self.layer_idx = layer_idx

        # Read the patch size from the loaded model instead of assuming the
        # DINOv2 default, so get_grid_size() stays correct for other configs.
        config_patch = getattr(getattr(self.model, 'config', None), 'patch_size', None)
        self.patch_size = config_patch if isinstance(config_patch, int) else self._DEFAULT_PATCH_SIZE

    def _forward_patch_tokens(
        self,
        images: Union[Image.Image, List[Image.Image]],
        normalize: bool
    ) -> torch.Tensor:
        """Run the model and return patch tokens as [batch, num_patches, hidden_dim]."""
        inputs = self.processor(images=images, return_tensors="pt", do_rescale=True).to(self.device)
        use_last = self.layer_idx == -1

        with torch.no_grad():
            # Intermediate hidden states are only requested when a specific
            # layer is selected; last_hidden_state is always returned.
            outputs = self.model(**inputs, output_hidden_states=not use_last)

            if use_last:
                hidden_states = outputs.last_hidden_state
            else:
                hidden_states = outputs.hidden_states[self.layer_idx]

            # Drop the first token (CLS).
            patches = hidden_states[:, 1:, :]

            if normalize:
                patches = F.normalize(patches, p=2, dim=-1)

        return patches

    def extract_patches(self, image: Image.Image, normalize: bool = True) -> torch.Tensor:
        """
        Extract patch embeddings from a single image.

        Args:
            image: PIL image.
            normalize: If True, L2-normalize each patch embedding.

        Returns:
            patches: Tensor [num_patches, hidden_dim]
        """
        return self._forward_patch_tokens(image, normalize).squeeze(0)

    def extract_patches_batch(self, images: List[Image.Image], normalize: bool = True) -> torch.Tensor:
        """
        Extract patch embeddings from several images in one forward pass.

        Args:
            images: List of PIL images.
            normalize: If True, L2-normalize each patch embedding.

        Returns:
            patches: Tensor [batch, num_patches, hidden_dim]
        """
        return self._forward_patch_tokens(images, normalize)

    def get_grid_size(self, image: Image.Image) -> Tuple[int, int]:
        """Return the patch grid size (h, w) for the preprocessed image."""
        inputs = self.processor(images=image, return_tensors="pt")
        h = inputs['pixel_values'].shape[-2] // self.patch_size
        w = inputs['pixel_values'].shape[-1] // self.patch_size
        return h, w


# =============================================================================
# MÉTODO 1: DENSE MATCHING (POSICIONAL)
# =============================================================================

class DenseMatchingDetector:
    """
    Anomaly detector based on dense matching (1:1 positional correspondence).

    Each patch of the test image is compared with the patch at the same grid
    position in a defect-free reference image, so it is meant for datasets
    with well-aligned images such as MVTec AD.

    Args:
        extractor: DINOv2FeatureExtractor instance. Its ``extract_patches``
            is called with default arguments, which L2-normalizes the
            embeddings; the dot product below is therefore a cosine
            similarity.
    """

    def __init__(self, extractor: DINOv2FeatureExtractor):
        self.extractor = extractor

    @staticmethod
    def _positional_scores(test_patches: torch.Tensor, ref_patches: torch.Tensor) -> torch.Tensor:
        """Return per-patch anomaly scores (1 - cosine similarity) for aligned patch sets."""
        # Without this check a [1, D] reference would broadcast silently and
        # produce a plausible-looking but meaningless map.
        if test_patches.shape != ref_patches.shape:
            raise ValueError(
                "Los patches de test y referencia tienen formas distintas: "
                f"{tuple(test_patches.shape)} vs {tuple(ref_patches.shape)}"
            )
        return 1 - (test_patches * ref_patches).sum(dim=-1)

    def _scores_to_maps(
        self,
        anomaly_scores: torch.Tensor,
        test_image: Image.Image,
        smooth_sigma: float
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Reshape flat patch scores to the patch grid and add a Gaussian-smoothed copy."""
        h, w = self.extractor.get_grid_size(test_image)
        anomaly_map = anomaly_scores.cpu().numpy().reshape(h, w)
        anomaly_map_smooth = ndimage.gaussian_filter(anomaly_map, sigma=smooth_sigma)
        return anomaly_map, anomaly_map_smooth

    def compute_anomaly_map(
        self, 
        test_image: Image.Image, 
        reference_image: Image.Image,
        smooth_sigma: float = 0.8
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute an anomaly map by 1:1 positional comparison.

        Args:
            test_image: Image to evaluate.
            reference_image: Defect-free reference image.
            smooth_sigma: Sigma of the Gaussian smoothing.

        Returns:
            anomaly_map: Unsmoothed anomaly map on the patch grid [h, w].
            anomaly_map_smooth: Smoothed anomaly map on the patch grid [h, w].

        Raises:
            ValueError: If both images do not yield the same patch layout.
        """
        test_patches = self.extractor.extract_patches(test_image)      # [N, D]
        ref_patches = self.extractor.extract_patches(reference_image)  # [N, D]
        anomaly_scores = self._positional_scores(test_patches, ref_patches)
        return self._scores_to_maps(anomaly_scores, test_image, smooth_sigma)

    def compute_anomaly_map_multi_ref(
        self,
        test_image: Image.Image,
        reference_images: List[Image.Image],
        aggregation: str = 'min',
        smooth_sigma: float = 0.8
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute an anomaly map against several reference images.

        Args:
            test_image: Image to evaluate.
            reference_images: Non-empty list of reference images.
            aggregation: 'min' (smallest distance over references) or
                'mean' (average distance over references).
            smooth_sigma: Sigma of the Gaussian smoothing.

        Returns:
            anomaly_map, anomaly_map_smooth (both on the patch grid).

        Raises:
            ValueError: If ``aggregation`` is unknown, ``reference_images`` is
                empty, or a reference yields a different patch layout.
        """
        # Fail early: previously any unknown value silently meant 'mean'.
        if aggregation not in ('min', 'mean'):
            raise ValueError(f"Agregación no válida: {aggregation}. Usar 'min' o 'mean'")
        if not reference_images:
            raise ValueError("Se requiere al menos una imagen de referencia")

        test_patches = self.extractor.extract_patches(test_image)  # [N, D]

        all_scores = torch.stack(
            [
                self._positional_scores(test_patches, self.extractor.extract_patches(ref_img))
                for ref_img in reference_images
            ],
            dim=0,
        )  # [num_refs, N]

        if aggregation == 'min':
            anomaly_scores = all_scores.min(dim=0)[0]
        else:
            anomaly_scores = all_scores.mean(dim=0)

        return self._scores_to_maps(anomaly_scores, test_image, smooth_sigma)

    def visualize(
        self,
        test_image: Image.Image,
        reference_image: Image.Image,
        anomaly_map: np.ndarray,
        title: str = "Dense Matching - Detección de Anomalías"
    ):
        """
        Plot the reference, the test image and the anomaly overlay side by side.

        The anomaly map is stretched over the test image through ``extent``,
        so a patch-grid map can be passed directly.

        Returns:
            The matplotlib figure (it is also shown with ``plt.show()``).
        """
        fig, axes = plt.subplots(1, 3, figsize=(18, 5))

        axes[0].imshow(reference_image)
        axes[0].set_title("Referencia (Sin defectos)")
        axes[0].axis('off')

        axes[1].imshow(test_image)
        axes[1].set_title("Imagen Test")
        axes[1].axis('off')

        axes[2].imshow(test_image)
        im = axes[2].imshow(
            anomaly_map, 
            cmap='jet', 
            alpha=0.5, 
            extent=(0, test_image.width, test_image.height, 0)
        )
        axes[2].set_title("Mapa de Anomalía")
        axes[2].axis('off')
        plt.colorbar(im, ax=axes[2], fraction=0.046, pad=0.04)

        plt.suptitle(title, fontsize=14)
        plt.tight_layout()
        plt.show()

        return fig


# =============================================================================
# MÉTODO 2: MEMORY BANK + k-NN (PATCHCORE-STYLE)
# =============================================================================

class MemoryBankDetector:
    """
    PatchCore-style anomaly detector: memory bank + k-NN.

    A memory bank is built from patch embeddings of defect-free images. A
    test patch is scored by how dissimilar it is to its nearest neighbours in
    the bank, so patches without close neighbours get high anomaly scores.

    Args:
        extractor: DINOv2FeatureExtractor instance. Its ``extract_patches``
            is called with default arguments (L2-normalized embeddings), so
            the matrix product used for scoring is a cosine similarity.
        k: Number of nearest neighbours used for scoring (>= 1).
        coreset_ratio: Fraction of the memory bank to keep (1.0 = keep all).
            The subset is drawn uniformly at random with ``torch.randperm``;
            it is not a greedy coreset selection.

    Raises:
        ValueError: If ``k < 1`` or ``coreset_ratio <= 0``.
    """

    def __init__(
        self, 
        extractor: DINOv2FeatureExtractor, 
        k: int = 1,
        coreset_ratio: float = 1.0
    ):
        if k < 1:
            raise ValueError(f"k debe ser >= 1, se recibió {k}")
        if coreset_ratio <= 0:
            raise ValueError(f"coreset_ratio debe ser > 0, se recibió {coreset_ratio}")

        self.extractor = extractor
        self.k = k
        self.coreset_ratio = coreset_ratio
        self.memory_bank = None

    def build_memory_bank(self, good_images: List[Image.Image], verbose: bool = True):
        """
        Build the memory bank from patches of defect-free images.

        Args:
            good_images: Non-empty list of defect-free PIL images.
            verbose: If True, print progress every 10 images and a summary.

        Raises:
            ValueError: If ``good_images`` is empty.
        """
        if not good_images:
            raise ValueError("Se requiere al menos una imagen para construir el memory bank")

        all_patches = []

        for i, img in enumerate(good_images):
            all_patches.append(self.extractor.extract_patches(img))  # [N, D]

            if verbose and (i + 1) % 10 == 0:
                print(f"Procesadas {i + 1}/{len(good_images)} imágenes")

        # Concatenate the patches of all images.
        self.memory_bank = torch.cat(all_patches, dim=0)  # [total_patches, D]

        # Optional random subsampling to reduce memory.
        if self.coreset_ratio < 1.0:
            # Keep at least one patch so the bank can never end up empty.
            n_samples = max(1, int(len(self.memory_bank) * self.coreset_ratio))
            indices = torch.randperm(len(self.memory_bank))[:n_samples]
            self.memory_bank = self.memory_bank[indices]

        if verbose:
            print(f"Memory Bank construido: {self.memory_bank.shape[0]} patches, "
                  f"dim={self.memory_bank.shape[1]}")

    def compute_anomaly_map(
        self, 
        test_image: Image.Image,
        smooth_sigma: float = 0.8
    ) -> Tuple[np.ndarray, np.ndarray, float]:
        """
        Compute an anomaly map by comparing test patches with the memory bank.

        Args:
            test_image: Image to evaluate.
            smooth_sigma: Sigma of the Gaussian smoothing.

        Returns:
            anomaly_map: Unsmoothed map on the patch grid [h, w].
            anomaly_map_smooth: Smoothed map on the patch grid [h, w].
            image_score: Image-level score (maximum of the smoothed map).

        Raises:
            RuntimeError: If ``build_memory_bank()`` has not been called.
        """
        if self.memory_bank is None:
            raise RuntimeError("Primero ejecuta build_memory_bank()")

        test_patches = self.extractor.extract_patches(test_image)  # [N, D]

        # sim[i, j] = similarity between test_patch[i] and memory_patch[j]
        sim_matrix = torch.mm(test_patches, self.memory_bank.t())  # [N, memory_size]

        # topk raises when k exceeds the bank size, so clamp it.
        k = min(self.k, self.memory_bank.shape[0])
        topk_sim, _ = sim_matrix.topk(k, dim=1)  # [N, k]

        # Anomaly = 1 - mean similarity of the k nearest neighbours.
        anomaly_scores = 1 - topk_sim.mean(dim=1)  # [N]

        h, w = self.extractor.get_grid_size(test_image)
        anomaly_map = anomaly_scores.cpu().numpy().reshape(h, w)

        anomaly_map_smooth = ndimage.gaussian_filter(anomaly_map, sigma=smooth_sigma)

        # Cast so the score is a plain Python float, as annotated.
        image_score = float(anomaly_map_smooth.max())

        return anomaly_map, anomaly_map_smooth, image_score

    def predict_batch(
        self,
        test_images: List[Image.Image],
        smooth_sigma: float = 0.8
    ) -> Tuple[List[np.ndarray], List[float]]:
        """
        Score a list of images one by one.

        Returns:
            anomaly_maps: Smoothed anomaly map of each image.
            image_scores: Image-level score of each image.
        """
        anomaly_maps = []
        image_scores = []

        for img in test_images:
            _, amap_smooth, score = self.compute_anomaly_map(img, smooth_sigma)
            anomaly_maps.append(amap_smooth)
            image_scores.append(score)

        return anomaly_maps, image_scores

    def visualize(
        self,
        test_image: Image.Image,
        anomaly_map: np.ndarray,
        image_score: float,
        title: str = "Memory Bank + k-NN - Detección de Anomalías"
    ):
        """
        Plot the test image next to its anomaly overlay.

        The anomaly map is stretched over the test image through ``extent``,
        so a patch-grid map can be passed directly.

        Returns:
            The matplotlib figure (it is also shown with ``plt.show()``).
        """
        fig, axes = plt.subplots(1, 2, figsize=(12, 5))

        axes[0].imshow(test_image)
        axes[0].set_title("Imagen Test")
        axes[0].axis('off')

        axes[1].imshow(test_image)
        im = axes[1].imshow(
            anomaly_map, 
            cmap='jet', 
            alpha=0.5, 
            extent=(0, test_image.width, test_image.height, 0)
        )
        axes[1].set_title(f"Mapa de Anomalía (Score: {image_score:.4f})")
        axes[1].axis('off')
        plt.colorbar(im, ax=axes[1], fraction=0.046, pad=0.04)

        plt.suptitle(title, fontsize=14)
        plt.tight_layout()
        plt.show()

        return fig


# =============================================================================
# UTILIDADES DE VISUALIZACIÓN
# =============================================================================

def visualize_layer_comparison(
    extractor: DINOv2FeatureExtractor,
    image: Image.Image,
    layers_to_compare: Union[List[int], Tuple[int, ...]] = (0, 3, 6, 9, 12, -1)
):
    """
    Compare PCA visualizations of the patch features of several model layers.

    For each requested layer the un-normalized patch embeddings are projected
    to 3 PCA components, min-max scaled per component and shown as an RGB
    image on the patch grid. Useful for choosing the extraction layer.

    Args:
        extractor: Feature extractor. Its ``layer_idx`` is changed while the
            figure is built and restored afterwards, even if extraction fails.
        image: PIL image to analyse.
        layers_to_compare: Layer indices to show (``-1`` is labelled as the
            last layer). The grid has 3 columns and as many rows as needed.

    Returns:
        The matplotlib figure (it is also shown with ``plt.show()``).

    Raises:
        ValueError: If ``layers_to_compare`` is empty.
    """
    layers = list(layers_to_compare)
    if not layers:
        raise ValueError("layers_to_compare no puede estar vacío")

    # 3 columns, ceil(len / 3) rows; the default six layers give the 2x3 grid.
    n_cols = 3
    n_rows = -(-len(layers) // n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)
    axes = axes.flatten()

    # get_grid_size does not read layer_idx, so it is computed once.
    h, w = extractor.get_grid_size(image)

    original_layer = extractor.layer_idx
    try:
        for ax, layer_idx in zip(axes, layers):
            extractor.layer_idx = layer_idx
            patches = extractor.extract_patches(image, normalize=False).cpu().numpy()

            # PCA to 3 components, one per RGB channel.
            pca_result = PCA(n_components=3).fit_transform(patches)

            # Scale each component to [0, 1]; the epsilon avoids 0/0 on a
            # constant component.
            pca_min = pca_result.min(axis=0)
            pca_max = pca_result.max(axis=0)
            pca_norm = (pca_result - pca_min) / (pca_max - pca_min + 1e-8)

            layer_name = "Última" if layer_idx == -1 else str(layer_idx)
            ax.imshow(pca_norm.reshape(h, w, 3))
            ax.set_title(f"Capa {layer_name}")
            ax.axis('off')
    finally:
        # Never leave the shared extractor pointing at a temporary layer.
        extractor.layer_idx = original_layer

    # Hide grid cells that have no layer assigned.
    for ax in axes[len(layers):]:
        ax.axis('off')

    plt.suptitle("Comparación de Features por Capa (PCA→RGB)", fontsize=14)
    plt.tight_layout()
    plt.show()

    return fig


def upsample_anomaly_map(
    anomaly_map: np.ndarray,
    target_size: Tuple[int, int],
    mode: str = 'bilinear'
) -> np.ndarray:
    """
    Rescale a 2D anomaly map to a target size with ``F.interpolate``.

    Args:
        anomaly_map: Anomaly map [H, W]. Any numeric or boolean dtype and any
            memory layout (including reversed views) is accepted.
        target_size: Target (height, width).
        mode: Interpolation mode passed to ``F.interpolate``, e.g.
            'bilinear', 'bicubic', 'nearest', 'nearest-exact' or 'area'.

    Returns:
        A float32 array of shape exactly ``target_size``, including when one
        of the target dimensions is 1.
    """
    # torch.from_numpy rejects negatively-strided arrays; a contiguous
    # float32 copy is made only when the input is not already one.
    source = np.ascontiguousarray(anomaly_map, dtype=np.float32)
    amap_tensor = torch.from_numpy(source)[None, None]  # [1, 1, H, W]
    size = (int(target_size[0]), int(target_size[1]))

    # align_corners is only accepted by the interpolating modes.
    if mode in ('bilinear', 'bicubic'):
        upsampled = F.interpolate(amap_tensor, size=size, mode=mode, align_corners=False)
    else:
        upsampled = F.interpolate(amap_tensor, size=size, mode=mode)

    # Index the batch and channel axes explicitly: squeeze() would also drop
    # a target dimension of size 1.
    return upsampled[0, 0].numpy()


# =============================================================================
# MÉTRICAS DE EVALUACIÓN
# =============================================================================

class AnomalyEvaluator:
    """
    Evaluador de métricas para detección de anomalías.
    
    Calcula métricas pixel-level y region-level comparando
    mapas de anomalía predichos contra ground truth masks.
    
    Métricas implementadas:
    - Pixel-level: IoU, Dice, Precision, Recall, F1
    - Region-level: PRO (Per-Region Overlap)
    
    Características mejoradas:
    - auto_normalize: Normalización automática al rango [0, 1]
    - Resize automático al tamaño del ground truth
    - Estadísticas de diagnóstico en resultados
    """
    
    def __init__(self, threshold: float = 0.5, auto_normalize: bool = True):
        """
        Args:
            threshold: Umbral para binarizar el mapa de anomal√≠a (0-1).
                      Se aplica DESPU√âS de normalizar al rango [0, 1].
            auto_normalize: Si True, normaliza autom√°ticamente el mapa
                           al rango [0, 1] antes de aplicar el umbral.
        """
        self.threshold = threshold
        self.auto_normalize = auto_normalize
    
    @staticmethod
    def load_ground_truth(gt_path: str) -> np.ndarray:
        """
        Load a ground-truth mask as a binary array.

        Args:
            gt_path: Path to the ground-truth image.

        Returns:
            mask: float32 array [H, W] with values 0.0 or 1.0. Pixels whose
                8-bit grayscale value is above 127 are marked as 1.
        """
        # The context manager closes the underlying file deterministically;
        # convert() loads the pixel data before the file is released.
        with Image.open(gt_path) as gt_image:
            gt_array = np.array(gt_image.convert('L'))

        # Binarize (the original note states MVTec uses 255 for anomalous
        # pixels and 0 for normal ones).
        return (gt_array > 127).astype(np.float32)
    
    def binarize_anomaly_map(
        self, 
        anomaly_map: np.ndarray, 
        threshold: float = None
    ) -> np.ndarray:
        """
        Binariza un mapa de anomal√≠a usando el umbral.
        
        Args:
            anomaly_map: Mapa de anomal√≠a [H, W] con valores en [0, 1]
            threshold: Umbral (usa self.threshold si es None)
            
        Returns:
            binary_mask: M√°scara binaria [H, W]
        """
        if threshold is None:
            threshold = self.threshold
        
        # Normalizar a [0, 1] si es necesario
        amap_min = anomaly_map.min()
        amap_max = anomaly_map.max()
        if amap_max > amap_min:
            normalized = (anomaly_map - amap_min) / (amap_max - amap_min)
        else:
            normalized = np.zeros_like(anomaly_map)
        
        return (normalized >= threshold).astype(np.float32)
    
    # -------------------------------------------------------------------------
    # PIXEL-LEVEL METRICS
    # -------------------------------------------------------------------------
    
    def compute_pixel_metrics(
        self, 
        pred_mask: np.ndarray, 
        gt_mask: np.ndarray
    ) -> dict:
        """
        Calcula m√©tricas a nivel de pixel.
        
        Args:
            pred_mask: M√°scara predicha binaria [H, W]
            gt_mask: Ground truth binario [H, W]
            
        Returns:
            dict con: IoU, Dice, Precision, Recall, F1
        """
        # Asegurar mismo tama√±o
        if pred_mask.shape != gt_mask.shape:
            pred_mask = upsample_anomaly_map(
                pred_mask, 
                target_size=gt_mask.shape,
                mode='nearest'
            )
        
        pred = pred_mask.flatten().astype(bool)
        gt = gt_mask.flatten().astype(bool)
        
        # Componentes de la matriz de confusi√≥n
        tp = np.sum(pred & gt)        # True Positives
        fp = np.sum(pred & ~gt)       # False Positives
        fn = np.sum(~pred & gt)       # False Negatives
        tn = np.sum(~pred & ~gt)      # True Negatives
        
        # M√©tricas
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
        
        intersection = tp
        union = tp + fp + fn
        iou = intersection / union if union > 0 else 0.0
        
        dice = 2 * intersection / (2 * tp + fp + fn) if (2 * tp + fp + fn) > 0 else 0.0
        
        return {
            'IoU': iou,
            'Dice': dice,
            'Precision': precision,
            'Recall': recall,
            'F1': f1,
            'TP': int(tp),
            'FP': int(fp),
            'FN': int(fn),
            'TN': int(tn)
        }
    
    # -------------------------------------------------------------------------
    # REGION-LEVEL METRICS (PRO)
    # -------------------------------------------------------------------------
    
    def compute_pro_single(
        self,
        pred_mask: np.ndarray,
        gt_mask: np.ndarray
    ) -> Tuple[float, int, List[float]]:
        """
        Calcula PRO para una sola imagen con un umbral fijo.
        
        F√≥rmula del paper:
        PRO = (1/N) * Œ£_i Œ£_k (|P_i ‚à© C_{i,k}| / |C_{i,k}|)
        
        Donde:
        - N = n√∫mero total de regiones conectadas en el ground truth
        - C_{i,k} = p√≠xeles del componente k en imagen i
        - P_i = p√≠xeles predichos como an√≥malos
        
        Args:
            pred_mask: M√°scara binaria de predicci√≥n [H, W]
            gt_mask: Ground truth binario [H, W]
            
        Returns:
            pro_score: Score PRO para esta imagen (promedio sobre regiones)
            num_regions: N√∫mero de regiones conectadas
            region_overlaps: Lista de overlaps por regi√≥n
        """
        from scipy import ndimage as ndi
        
        # Asegurar mismo tama√±o
        if pred_mask.shape != gt_mask.shape:
            pred_mask = upsample_anomaly_map(
                pred_mask,
                target_size=gt_mask.shape,
                mode='nearest'
            )
        
        # Encontrar regiones conectadas en ground truth
        labeled_gt, num_regions = ndi.label(gt_mask > 0)
        
        if num_regions == 0:
            return 1.0, 0, []
        
        # Calcular overlap para cada regi√≥n: |P ‚à© C_k| / |C_k|
        region_overlaps = []
        pred_binary = pred_mask > 0
        
        for region_id in range(1, num_regions + 1):
            region_mask = (labeled_gt == region_id)
            region_size = np.sum(region_mask)  # |C_k|
            
            if region_size > 0:
                intersection = np.sum(pred_binary & region_mask)  # |P ‚à© C_k|
                overlap = intersection / region_size
                region_overlaps.append(overlap)
        
        # PRO para esta imagen = promedio sobre regiones
        pro_score = np.mean(region_overlaps) if region_overlaps else 0.0
        
        return pro_score, num_regions, region_overlaps
    
    def compute_pro(
        self, 
        anomaly_map: np.ndarray, 
        gt_mask: np.ndarray,
        threshold: float = None
    ) -> Tuple[float, int, List[float]]:
        """
        Calcula PRO para un mapa de anomal√≠a usando un umbral.
        
        Args:
            anomaly_map: Mapa de anomal√≠a [H, W] (valores continuos)
            gt_mask: Ground truth binario [H, W]
            threshold: Umbral para binarizar (usa self.threshold si None)
            
        Returns:
            pro_score: Score PRO
            num_regions: N√∫mero de regiones
            region_overlaps: Overlaps por regi√≥n
        """
        # Binarizar mapa de anomal√≠a
        pred_mask = self.binarize_anomaly_map(anomaly_map, threshold)
        return self.compute_pro_single(pred_mask, gt_mask)
    
    def compute_au_pro(
        self, 
        anomaly_map: np.ndarray, 
        gt_mask: np.ndarray,
        num_thresholds: int = 100,
        fpr_limit: float = 0.3
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
        """
        Calcula AU-PRO (Area Under PRO curve).
        
        Calcula la curva PRO vs FPR para m√∫ltiples umbrales y el √°rea
        bajo la curva hasta un l√≠mite de FPR (t√≠picamente 0.3).
        
        Args:
            anomaly_map: Mapa de anomal√≠a [H, W] (valores continuos)
            gt_mask: Ground truth binario [H, W]
            num_thresholds: N√∫mero de umbrales para la curva
            fpr_limit: L√≠mite de FPR para integraci√≥n (default 0.3)
            
        Returns:
            thresholds: Umbrales usados
            fpr_values: FPR para cada umbral
            pro_values: PRO para cada umbral
            au_pro: √Årea bajo la curva PRO normalizada
        """
        from scipy import ndimage as ndi
        
        # Asegurar mismo tama√±o
        if anomaly_map.shape != gt_mask.shape:
            anomaly_map = upsample_anomaly_map(
                anomaly_map, 
                target_size=gt_mask.shape,
                mode='bilinear'
            )
        
        # Normalizar mapa a [0, 1]
        amap_min = anomaly_map.min()
        amap_max = anomaly_map.max()
        if amap_max > amap_min:
            anomaly_map_norm = (anomaly_map - amap_min) / (amap_max - amap_min)
        else:
            anomaly_map_norm = np.zeros_like(anomaly_map)
        
        # Encontrar regiones conectadas en ground truth
        labeled_gt, num_regions = ndi.label(gt_mask > 0)
        
        if num_regions == 0:
            return np.array([0.0]), np.array([0.0]), np.array([1.0]), 1.0
        
        # Calcular PRO y FPR para m√∫ltiples umbrales
        thresholds = np.linspace(0, 1, num_thresholds)
        fpr_values = []
        pro_values = []
        
        total_normal_pixels = np.sum(gt_mask == 0)
        
        for thresh in thresholds:
            pred_mask = (anomaly_map_norm >= thresh)
            
            # FPR: False Positive Rate en p√≠xeles normales
            if total_normal_pixels > 0:
                fp = np.sum(pred_mask & (gt_mask == 0))
                fpr = fp / total_normal_pixels
            else:
                fpr = 0.0
            
            # PRO: (1/N) * Œ£_k (|P ‚à© C_k| / |C_k|)
            region_overlaps = []
            for region_id in range(1, num_regions + 1):
                region_mask = (labeled_gt == region_id)
                region_size = np.sum(region_mask)
                
                if region_size > 0:
                    intersection = np.sum(pred_mask & region_mask)
                    overlap = intersection / region_size
                    region_overlaps.append(overlap)
            
            pro = np.mean(region_overlaps) if region_overlaps else 0.0
            
            fpr_values.append(fpr)
            pro_values.append(pro)
        
        thresholds = np.array(thresholds)
        fpr_values = np.array(fpr_values)
        pro_values = np.array(pro_values)
        
        # Calcular AU-PRO (√°rea bajo curva hasta FPR limit)
        valid_idx = fpr_values <= fpr_limit
        
        if np.sum(valid_idx) > 1:
            # Ordenar por FPR e integrar
            sorted_idx = np.argsort(fpr_values[valid_idx])
            fpr_sorted = fpr_values[valid_idx][sorted_idx]
            pro_sorted = pro_values[valid_idx][sorted_idx]
            au_pro = np.trapz(pro_sorted, fpr_sorted) / fpr_limit
        else:
            au_pro = 0.0
        
        return thresholds, fpr_values, pro_values, au_pro
    
    def evaluate(
        self,
        anomaly_map: np.ndarray,
        gt_mask: np.ndarray,
        threshold: float = None
    ) -> dict:
        """
        Compute every metric for a single anomaly map.

        Processing steps:
        1. Record statistics of the raw input map (for diagnostics).
        2. Resize the map to the ground-truth shape when the shapes differ.
        3. Min-max normalise the map (only when ``self.auto_normalize`` is set).
        4. Binarise with the threshold and score the result.

        Args:
            anomaly_map: Anomaly map [H, W] (raw model scores).
            gt_mask: Binary ground-truth mask [H, W].
            threshold: Binarisation threshold. ``self.threshold`` is used when None.

        Returns:
            dict containing everything returned by ``compute_pixel_metrics``
            plus:
            - 'PRO', 'AU-PRO', 'Num_Regions', 'Threshold'
            - diagnostics of the raw map: 'orig_min', 'orig_max', 'orig_mean',
              'orig_std', and 'normalized' (the ``auto_normalize`` flag)
            - '_anomaly_map_normalized' and '_pred_mask': the thresholded map
              and the resulting binary mask, kept for visualisation
        """
        # An explicit per-call threshold wins over the evaluator default.
        cutoff = self.threshold if threshold is None else threshold

        # Diagnostics come from the untouched input, before resize/normalisation.
        raw_stats = {
            'orig_min': float(anomaly_map.min()),
            'orig_max': float(anomaly_map.max()),
            'orig_mean': float(anomaly_map.mean()),
            'orig_std': float(anomaly_map.std()),
        }

        # Bring the map to the ground-truth resolution; copy when no resize is
        # needed so the caller's array is never aliased in the result.
        if anomaly_map.shape == gt_mask.shape:
            scores = anomaly_map.copy()
        else:
            scores = resize_anomaly_map(
                anomaly_map,
                target_size=gt_mask.shape,
                interpolation='bilinear'
            )

        # Optional rescaling to [0, 1] so the threshold is comparable across maps.
        if self.auto_normalize:
            scores = normalize_anomaly_map(scores, method='minmax')

        binary_pred = (scores >= cutoff).astype(np.float32)

        # Pixel-level metrics first, then region-level PRO at this threshold,
        # then the threshold-free AU-PRO on the continuous scores.
        result = dict(self.compute_pixel_metrics(binary_pred, gt_mask))
        pro_score, region_count, _ = self.compute_pro_single(binary_pred, gt_mask)
        _, _, _, au_pro = self.compute_au_pro(scores, gt_mask)

        result.update({
            'PRO': pro_score,
            'AU-PRO': au_pro,
            'Num_Regions': region_count,
            'Threshold': cutoff,
            **raw_stats,
            'normalized': self.auto_normalize,
            '_anomaly_map_normalized': scores,
            '_pred_mask': binary_pred,
        })
        return result
    
    def visualize_comparison(
        self,
        test_image: Image.Image,
        anomaly_map: np.ndarray,
        gt_mask: np.ndarray,
        metrics: dict,
        title: str = "Comparaci√≥n: Predicci√≥n vs Ground Truth",
        show_original_values: bool = True
    ):
        """
        Plot the prediction next to the ground truth, with the metrics below.

        Panels (left to right):
        1. Test image
        2. Ground truth overlay
        3. Raw anomaly map overlay (only when ``show_original_values`` is True)
        4. Min-max normalised map overlay
        5. Binarised prediction overlay

        The binary panel uses the threshold stored in ``metrics['Threshold']``
        (falling back to ``self.threshold``) and thresholds the normalised map
        only when ``self.auto_normalize`` is set, i.e. it binarises the same
        values ``evaluate`` does, so the picture matches the reported numbers.

        Args:
            test_image: Image shown as the background of every panel.
            anomaly_map: Anomaly map [H, W]; resized to ``gt_mask.shape`` when
                the shapes differ.
            gt_mask: Ground-truth mask [H, W].
            metrics: Metrics dict; must contain 'IoU', 'Dice', 'F1',
                'Precision', 'Recall' and 'AU-PRO'. 'Threshold' is optional.
            title: Figure title.
            show_original_values: Whether to include the raw-map panel.

        Returns:
            The matplotlib figure (it is also displayed via ``plt.show()``).
        """
        # Match the ground-truth resolution before anything else.
        if anomaly_map.shape != gt_mask.shape:
            amap_resized = resize_anomaly_map(anomaly_map, gt_mask.shape)
        else:
            amap_resized = anomaly_map.copy()

        # Keep an untouched copy of the raw values for the "original" panel.
        amap_original = amap_resized.copy()
        orig_min, orig_max = amap_original.min(), amap_original.max()

        amap_normalized = normalize_anomaly_map(amap_resized, method='minmax')

        # Binarise exactly what evaluate() binarised: same threshold, same map.
        shown_threshold = metrics.get('Threshold', self.threshold)
        scored_map = amap_normalized if self.auto_normalize else amap_resized
        pred_binary = scored_map >= shown_threshold

        # All overlays are stretched onto the test image's pixel grid.
        overlay_extent = (0, test_image.width, test_image.height, 0)

        n_cols = 5 if show_original_values else 4
        fig, axes = plt.subplots(1, n_cols, figsize=(5 * n_cols, 5))

        ax_idx = 0

        # Panel 1: test image
        axes[ax_idx].imshow(test_image)
        axes[ax_idx].set_title("Imagen Test")
        axes[ax_idx].axis('off')
        ax_idx += 1

        # Panel 2: ground truth
        axes[ax_idx].imshow(test_image)
        axes[ax_idx].imshow(
            resize_anomaly_map(gt_mask.astype(float), (test_image.height, test_image.width)),
            cmap='Reds', alpha=0.5,
            extent=overlay_extent
        )
        axes[ax_idx].set_title(f"Ground Truth\n({gt_mask.sum()} p√≠xeles)")
        axes[ax_idx].axis('off')
        ax_idx += 1

        if show_original_values:
            # Panel 3: raw (un-normalised) map
            axes[ax_idx].imshow(test_image)
            im_orig = axes[ax_idx].imshow(
                amap_original,
                cmap='jet',
                alpha=0.5,
                extent=overlay_extent
            )
            axes[ax_idx].set_title(f"Mapa Original\n[{orig_min:.4f}, {orig_max:.4f}]")
            axes[ax_idx].axis('off')
            plt.colorbar(im_orig, ax=axes[ax_idx], fraction=0.046, pad=0.04)
            ax_idx += 1

        # Panel 4: normalised map
        axes[ax_idx].imshow(test_image)
        im_norm = axes[ax_idx].imshow(
            amap_normalized,
            cmap='jet',
            alpha=0.5,
            vmin=0, vmax=1,
            extent=overlay_extent
        )
        axes[ax_idx].set_title("Mapa Normalizado\n[0.0, 1.0]")
        axes[ax_idx].axis('off')
        plt.colorbar(im_norm, ax=axes[ax_idx], fraction=0.046, pad=0.04)
        ax_idx += 1

        # Panel 5: binarised prediction
        axes[ax_idx].imshow(test_image)
        axes[ax_idx].imshow(
            pred_binary,
            cmap='Blues',
            alpha=0.5,
            extent=overlay_extent
        )
        n_pred = pred_binary.sum()
        axes[ax_idx].set_title(f"Predicci√≥n Binaria\n(œÑ={shown_threshold}, {n_pred} p√≠xeles)")
        axes[ax_idx].axis('off')

        # Metrics rendered as a text box under the panels.
        metrics_text = (
            f"IoU: {metrics['IoU']:.3f} | "
            f"Dice: {metrics['Dice']:.3f} | "
            f"F1: {metrics['F1']:.3f} | "
            f"Precision: {metrics['Precision']:.3f} | "
            f"Recall: {metrics['Recall']:.3f} | "
            f"AU-PRO: {metrics['AU-PRO']:.3f}"
        )

        plt.figtext(0.5, 0.02, metrics_text, ha='center', fontsize=10,
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

        plt.suptitle(title, fontsize=14)
        # Leave room for the metrics box (bottom) and the suptitle (top).
        plt.tight_layout(rect=[0, 0.06, 1, 0.96])
        plt.show()

        return fig


def load_test_with_ground_truth(
    test_folder: str,
    gt_folder: str,
    n_images: Optional[int] = None
) -> List[Tuple[Image.Image, np.ndarray, str]]:
    """
    Load test images together with their ground-truth masks.

    For every image file in ``test_folder`` the mask is looked up in
    ``gt_folder`` first as ``<stem>_mask.png`` (MVTec naming) and then under
    the image's own file name. Images without a mask are reported and skipped.

    Args:
        test_folder: Folder with test images (e.g. .../test/broken_large)
        gt_folder: Folder with ground truth (e.g. .../ground_truth/broken_large)
        n_images: Maximum number of test files to consider (None = all). The
            limit is applied before the mask lookup, so fewer pairs may be
            returned when some masks are missing.

    Returns:
        List of (test_image, gt_mask, filename) tuples, sorted by file name.
    """
    import os

    extensions = ('.png', '.jpg', '.jpeg', '.bmp')

    test_files = sorted(
        f for f in os.listdir(test_folder)
        if f.lower().endswith(extensions)
    )

    if n_images is not None:
        test_files = test_files[:n_images]

    results = []
    for filename in test_files:
        # Swap only the real extension: str.replace('.png', ...) would also
        # rewrite '.png' inside the stem and would miss upper-case '.PNG'.
        stem = os.path.splitext(filename)[0]
        candidates = (
            os.path.join(gt_folder, f"{stem}_mask.png"),
            os.path.join(gt_folder, filename),
        )
        gt_path = next((p for p in candidates if os.path.exists(p)), None)

        # Resolve the mask before decoding the image so that images without
        # ground truth cost nothing.
        if gt_path is None:
            print(f"‚ö†Ô∏è Ground truth no encontrado para {filename}")
            continue

        gt_mask = AnomalyEvaluator.load_ground_truth(gt_path)
        if gt_mask is None:
            continue

        # convert() returns a new, fully loaded image, so the file handle can
        # be released deterministically as soon as the block exits.
        with Image.open(os.path.join(test_folder, filename)) as raw:
            test_img = raw.convert("RGB")

        results.append((test_img, gt_mask, filename))

    print(f"Cargados {len(results)} pares (test, ground_truth) de {test_folder}")
    return results


# =============================================================================
# EJEMPLO DE USO
# =============================================================================

def load_images_from_folder(
    folder_path: str,
    n_images: int = None,
    extensions: tuple = ('.png', '.jpg', '.jpeg', '.bmp')
) -> List[Image.Image]:
    """
    Load up to N images from a folder as RGB PIL images.

    Files are selected by (case-insensitive) extension and processed in
    sorted file-name order.

    Args:
        folder_path: Path of the folder containing the images
        n_images: Number of images to load (None = all)
        extensions: Accepted file extensions

    Returns:
        List of PIL images converted to RGB
    """
    import os

    # Candidate files, in deterministic (alphabetical) order.
    candidates = [
        name for name in os.listdir(folder_path)
        if name.lower().endswith(extensions)
    ]
    candidates.sort()

    # Keep only the first N when a limit was requested.
    selected = candidates if n_images is None else candidates[:n_images]

    images = [
        Image.open(os.path.join(folder_path, name)).convert("RGB")
        for name in selected
    ]

    print(f"Cargadas {len(images)} im√°genes de {folder_path}")
    return images


# =============================================================================
# EVALUADOR COMPLETO MVTEC AD
# =============================================================================

class MVTecDatasetEvaluator:
    """
    End-to-end evaluator for the MVTec AD dataset.

    Iterates over categories and their anomaly types, builds one memory bank
    per category from the 'good' training images, and aggregates metrics
    globally, per category and per anomaly type.

    Args:
        dataset_path: Root of the MVTec AD dataset (holds the category folders)
        model_path: Path to the DINOv2 model
        layer_idx: Index of the layer to extract features from (-1 = last)
        n_good_images: Number of 'good' images for the memory bank (None = all)
        k: Number of neighbours for k-NN
        coreset_ratio: Sub-sampling ratio of the memory bank
        threshold: Binarisation threshold
        auto_normalize: If True, anomaly maps are normalised automatically
        device: Forwarded unchanged to the feature extractor
    """

    # The 15 MVTec AD categories
    CATEGORIES = [
        'bottle', 'cable', 'capsule', 'carpet', 'grid',
        'hazelnut', 'leather', 'metal_nut', 'pill', 'screw',
        'tile', 'toothbrush', 'transistor', 'wood', 'zipper'
    ]

    # Metrics averaged by _compute_summary and listed in the global report.
    _METRIC_KEYS = ('IoU', 'Dice', 'Precision', 'Recall', 'F1', 'PRO', 'AU-PRO')
    # Column order of the per-category tables (console and Markdown).
    _CATEGORY_COLUMNS = ('IoU', 'Dice', 'F1', 'Precision', 'Recall', 'AU-PRO')
    # Column order of the per-anomaly-type Markdown tables.
    _ANOMALY_COLUMNS = ('IoU', 'Dice', 'F1', 'AU-PRO')

    def __init__(
        self,
        dataset_path: str,
        model_path: str,
        layer_idx: int = -1,
        n_good_images: int = None,
        k: int = 1,
        coreset_ratio: float = 1.0,
        threshold: float = 0.5,
        auto_normalize: bool = True,
        device: str = None
    ):
        """Store the configuration and create the extractor and metric evaluator."""
        self.dataset_path = dataset_path
        self.model_path = model_path
        self.layer_idx = layer_idx
        self.n_good_images = n_good_images
        self.k = k
        self.coreset_ratio = coreset_ratio
        self.threshold = threshold
        self.auto_normalize = auto_normalize
        self.device = device

        # Feature extractor shared by the memory banks of every category.
        self.extractor = DINOv2FeatureExtractor(
            model_path=model_path,
            layer_idx=layer_idx,
            device=device
        )

        # Metric evaluator
        self.evaluator = AnomalyEvaluator(
            threshold=threshold,
            auto_normalize=auto_normalize
        )

        # Only categories that actually exist on disk are evaluated by default.
        self.available_categories = self._discover_categories()
        print(f"üîç Encontradas {len(self.available_categories)} categor√≠as en {dataset_path}")

    def _discover_categories(self) -> List[str]:
        """Return the known categories that exist as folders under the dataset root."""
        import os
        return [
            cat for cat in self.CATEGORIES
            if os.path.isdir(os.path.join(self.dataset_path, cat))
        ]

    def _get_anomaly_types(self, category: str) -> List[str]:
        """
        Return the anomaly types of a category, sorted by name.

        A type is a sub-folder of ``<category>/test`` other than 'good' that
        also has a matching folder under ``<category>/ground_truth``.
        """
        import os
        test_path = os.path.join(self.dataset_path, category, 'test')
        gt_path = os.path.join(self.dataset_path, category, 'ground_truth')

        if not os.path.isdir(test_path):
            return []

        return [
            folder for folder in sorted(os.listdir(test_path))
            if folder != 'good'
            and os.path.isdir(os.path.join(test_path, folder))
            and os.path.isdir(os.path.join(gt_path, folder))
        ]

    def _load_good_images(self, category: str) -> List[Image.Image]:
        """Load the 'good' training images of a category (at most ``n_good_images``)."""
        import os
        good_path = os.path.join(self.dataset_path, category, 'train', 'good')
        return load_images_from_folder(good_path, n_images=self.n_good_images)

    def _build_memory_bank(self, category: str, verbose: bool = True) -> MemoryBankDetector:
        """Build a memory-bank detector for a category from its 'good' images."""
        if verbose:
            print(f"\nüì¶ Construyendo Memory Bank para '{category}'...")

        good_images = self._load_good_images(category)

        detector = MemoryBankDetector(
            extractor=self.extractor,
            k=self.k,
            coreset_ratio=self.coreset_ratio
        )
        detector.build_memory_bank(good_images, verbose=verbose)

        return detector

    def evaluate_category(
        self,
        category: str,
        detector: Optional[MemoryBankDetector] = None,
        verbose: bool = True
    ) -> dict:
        """
        Evaluate every anomaly type of one category.

        Args:
            category: Category name
            detector: Memory-bank detector (a new one is built when None)
            verbose: If True, print progress

        Returns:
            dict with keys 'category', 'anomaly_results' (per-type metrics and
            summary), 'all_metrics' (flat per-image list) and 'summary'
            (category-wide averages; empty when nothing was evaluated).
        """
        import os

        if detector is None:
            detector = self._build_memory_bank(category, verbose)

        anomaly_types = self._get_anomaly_types(category)

        if verbose:
            print(f"\nüîç Evaluando {len(anomaly_types)} tipos de anomal√≠a en '{category}'")

        category_results = {
            'category': category,
            'anomaly_results': {},
            'all_metrics': [],
            'summary': {}
        }

        for anomaly_type in anomaly_types:
            test_folder = os.path.join(self.dataset_path, category, 'test', anomaly_type)
            gt_folder = os.path.join(self.dataset_path, category, 'ground_truth', anomaly_type)

            if not os.path.isdir(gt_folder):
                if verbose:
                    print(f"  ‚ö†Ô∏è Sin ground truth para {anomaly_type}, saltando...")
                continue

            # Load every (test image, ground truth) pair of this anomaly type.
            test_data = load_test_with_ground_truth(
                test_folder=test_folder,
                gt_folder=gt_folder,
                n_images=None
            )

            if not test_data:
                if verbose:
                    print(f"  ‚ö†Ô∏è No hay datos para {anomaly_type}")
                continue

            if verbose:
                print(f"\n  üìç Tipo: {anomaly_type} ({len(test_data)} im√°genes)")

            anomaly_metrics = []

            for test_img, gt_mask, filename in test_data:
                # Only the smoothed map and the image-level score are used.
                _, amap_smooth, score = detector.compute_anomaly_map(test_img)

                # NOTE(review): the dict returned by evaluate() carries the
                # full-resolution '_anomaly_map_normalized' and '_pred_mask'
                # arrays, which are retained for every image of the run.
                # Consider dropping them here if memory becomes a problem.
                metrics = self.evaluator.evaluate(amap_smooth, gt_mask)
                metrics['filename'] = filename
                metrics['category'] = category
                metrics['anomaly_type'] = anomaly_type
                metrics['image_score'] = score

                anomaly_metrics.append(metrics)

            # Averages for this anomaly type
            anomaly_summary = self._compute_summary(anomaly_metrics)
            anomaly_summary['n_images'] = len(anomaly_metrics)

            category_results['anomaly_results'][anomaly_type] = {
                'metrics': anomaly_metrics,
                'summary': anomaly_summary
            }
            category_results['all_metrics'].extend(anomaly_metrics)

            if verbose:
                print(f"     IoU: {anomaly_summary['IoU']:.4f} | "
                      f"Dice: {anomaly_summary['Dice']:.4f} | "
                      f"F1: {anomaly_summary['F1']:.4f} | "
                      f"AU-PRO: {anomaly_summary['AU-PRO']:.4f}")

        # Category-wide summary over every evaluated image
        if category_results['all_metrics']:
            category_results['summary'] = self._compute_summary(category_results['all_metrics'])
            category_results['summary']['n_images'] = len(category_results['all_metrics'])
            category_results['summary']['n_anomaly_types'] = len(anomaly_types)

        return category_results

    def _compute_summary(self, metrics_list: List[dict]) -> dict:
        """
        Average the standard metrics over a list of per-image metric dicts.

        Returns a dict with ``<metric>`` (mean) and ``<metric>_std`` for each
        metric present in at least one entry; empty for an empty list.
        """
        if not metrics_list:
            return {}

        summary = {}
        for key in self._METRIC_KEYS:
            values = [m[key] for m in metrics_list if key in m]
            if values:
                summary[key] = np.mean(values)
                summary[f'{key}_std'] = np.std(values)

        return summary

    def evaluate_all(
        self,
        categories: List[str] = None,
        verbose: bool = True,
        save_results: bool = True,
        output_path: str = None
    ) -> dict:
        """
        Evaluate all categories (or the ones given).

        A failure in one category is printed with its traceback and does not
        stop the remaining categories.

        Args:
            categories: Categories to evaluate (None = every available one)
            verbose: If True, print detailed progress
            save_results: If True, write the results to a JSON file
            output_path: JSON destination; defaults to a time-stamped file next
                to the dataset folder

        Returns:
            dict with the configuration, per-category results, per-anomaly-type
            summaries and the global summary.
        """
        import os
        import traceback
        from datetime import datetime

        if categories is None:
            categories = self.available_categories

        print("\n" + "=" * 80)
        print("üöÄ EVALUACI√ìN COMPLETA DEL DATASET MVTEC AD")
        print("=" * 80)
        print(f"   Modelo: {self.model_path}")
        print(f"   Capa: {self.layer_idx}")
        print(f"   k-NN: k={self.k}")
        print(f"   Umbral: {self.threshold}")
        print(f"   Categor√≠as: {len(categories)}")
        print("=" * 80)

        all_results = {
            'config': {
                'model_path': self.model_path,
                'layer_idx': self.layer_idx,
                'k': self.k,
                'coreset_ratio': self.coreset_ratio,
                'threshold': self.threshold,
                'auto_normalize': self.auto_normalize,
                'n_good_images': self.n_good_images,
                'timestamp': datetime.now().isoformat()
            },
            'category_results': {},
            'global_metrics': [],
            'summary_by_category': {},
            'summary_by_anomaly_type': {},
            'global_summary': {}
        }

        for i, category in enumerate(categories, 1):
            print(f"\n{'='*80}")
            print(f"üìÇ [{i}/{len(categories)}] Categor√≠a: {category.upper()}")
            print("=" * 80)

            try:
                cat_results = self.evaluate_category(category, verbose=verbose)
                all_results['category_results'][category] = cat_results
                all_results['global_metrics'].extend(cat_results['all_metrics'])
                all_results['summary_by_category'][category] = cat_results['summary']

                # Expose each anomaly type under a "category/type" key.
                for anomaly_type, anom_data in cat_results['anomaly_results'].items():
                    key = f"{category}/{anomaly_type}"
                    all_results['summary_by_anomaly_type'][key] = anom_data['summary']

            except Exception as e:
                # Deliberate best-effort boundary: report and keep going.
                print(f"  ‚ùå Error procesando {category}: {e}")
                traceback.print_exc()

        # Global metrics over every evaluated image
        if all_results['global_metrics']:
            all_results['global_summary'] = self._compute_summary(all_results['global_metrics'])
            all_results['global_summary']['n_total_images'] = len(all_results['global_metrics'])
            all_results['global_summary']['n_categories'] = len(categories)

        self._print_final_summary(all_results)

        if save_results:
            if output_path is None:
                output_path = os.path.join(
                    self.dataset_path, '..',
                    f'evaluation_results_{datetime.now().strftime("%Y%m%d_%H%M%S")}.json'
                )
            self._save_results(all_results, output_path)

        return all_results

    def _print_final_summary(self, results: dict):
        """Print the per-category table and the global metrics to stdout."""
        print("\n" + "=" * 80)
        print("üìä RESUMEN FINAL DE EVALUACI√ìN")
        print("=" * 80)

        # Per-category table
        print("\nüè∑Ô∏è  M√âTRICAS POR CATEGOR√çA:")
        print("-" * 80)
        print(f"{'Categor√≠a':<15} {'IoU':<10} {'Dice':<10} {'F1':<10} {'Precision':<10} {'Recall':<10} {'AU-PRO':<10}")
        print("-" * 80)

        for category, summary in results['summary_by_category'].items():
            if summary:
                cells = " ".join(
                    f"{summary.get(col, 0):<10.4f}" for col in self._CATEGORY_COLUMNS
                )
                print(f"{category:<15} {cells}")

        # Global summary
        print("\n" + "=" * 80)
        print("üåç M√âTRICAS GLOBALES:")
        print("=" * 80)
        gs = results['global_summary']
        if gs:
            print(f"   Total im√°genes evaluadas: {gs.get('n_total_images', 0)}")
            print(f"   Total categor√≠as: {gs.get('n_categories', 0)}")
            print()
            print(f"   üìà M√©tricas Pixel-Level:")
            for name in ('IoU', 'Dice', 'Precision', 'Recall', 'F1'):
                # Labels are padded to 11 characters so the values line up.
                print(f"      {name + ':':<11}{gs.get(name, 0):.4f} (¬± {gs.get(f'{name}_std', 0):.4f})")
            print()
            print(f"   üìà M√©tricas Region-Level:")
            for name in ('PRO', 'AU-PRO'):
                print(f"      {name + ':':<11}{gs.get(name, 0):.4f} (¬± {gs.get(f'{name}_std', 0):.4f})")

    def _save_results(self, results: dict, output_path: str):
        """Write the results to ``output_path`` as JSON, creating parent folders."""
        import json
        import os

        # Strip internal entries and convert NumPy values to plain Python.
        results_clean = self._clean_results_for_json(results)

        os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(results_clean, f, indent=2, ensure_ascii=False)

        print(f"\nüíæ Resultados guardados en: {output_path}")

    def _clean_results_for_json(self, results: dict) -> dict:
        """
        Return a JSON-serialisable copy of ``results``.

        Dict entries whose key is a string starting with '_' are dropped,
        NumPy scalars become the matching Python type (bool/int/float),
        arrays become nested lists, and lists/tuples are cleaned recursively.
        """
        def clean_value(v):
            # bool must be handled explicitly: json cannot encode np.bool_.
            if isinstance(v, np.bool_):
                return bool(v)
            # Keep integers integral (counts must not turn into 3.0).
            if isinstance(v, np.integer):
                return int(v)
            if isinstance(v, np.floating):
                return float(v)
            if isinstance(v, np.ndarray):
                return v.tolist()
            if isinstance(v, dict):
                return {
                    k: clean_value(val) for k, val in v.items()
                    if not (isinstance(k, str) and k.startswith('_'))
                }
            if isinstance(v, (list, tuple)):
                return [clean_value(item) for item in v]
            return v

        return clean_value(results)

    def generate_report(
        self,
        results: dict,
        output_path: str = None
    ) -> str:
        """
        Build a detailed Markdown report.

        Args:
            results: Results of evaluate_all()
            output_path: Where to save the report (optional)

        Returns:
            The report content as a Markdown string
        """
        from datetime import datetime
        import os

        report = []
        report.append("# üìä Reporte de Evaluaci√≥n MVTec AD")
        report.append(f"\n**Fecha:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")

        # Configuration
        config = results.get('config', {})
        # evaluate_all() always stores the key (None means "all images"), so a
        # .get() default alone would never apply.
        n_good = config.get('n_good_images')
        n_good_label = 'Todas' if n_good is None else n_good
        report.append("## ‚öôÔ∏è Configuraci√≥n\n")
        report.append(f"- **Modelo:** `{config.get('model_path', 'N/A')}`")
        report.append(f"- **Capa DINOv2:** {config.get('layer_idx', 'N/A')}")
        report.append(f"- **k-NN (k):** {config.get('k', 'N/A')}")
        report.append(f"- **Umbral:** {config.get('threshold', 'N/A')}")
        report.append(f"- **Im√°genes 'good' para Memory Bank:** {n_good_label}")
        report.append("")

        # Global summary
        gs = results.get('global_summary', {})
        report.append("## üåç M√©tricas Globales\n")
        report.append(f"- **Total im√°genes evaluadas:** {gs.get('n_total_images', 0)}")
        report.append(f"- **Total categor√≠as:** {gs.get('n_categories', 0)}")
        report.append("")
        report.append("| M√©trica | Valor | Desv. Est. |")
        report.append("|---------|-------|------------|")
        for metric in self._METRIC_KEYS:
            val = gs.get(metric, 0)
            std = gs.get(f'{metric}_std', 0)
            report.append(f"| {metric} | {val:.4f} | ¬± {std:.4f} |")
        report.append("")

        # Per-category table. Delimiter cells may contain only '-' (and ':'),
        # otherwise Markdown renderers do not recognise the table.
        report.append("## üè∑Ô∏è M√©tricas por Categor√≠a\n")
        report.append("| Categor√≠a | IoU | Dice | F1 | Precision | Recall | AU-PRO |")
        report.append("|-----------|-----|------|----|-----------|--------|--------|")

        for category, summary in results.get('summary_by_category', {}).items():
            if summary:
                cells = " | ".join(
                    f"{summary.get(col, 0):.4f}" for col in self._CATEGORY_COLUMNS
                )
                report.append(f"| {category} | {cells} |")
        report.append("")

        # Per-anomaly-type detail
        report.append("## üî¨ Detalle por Tipo de Anomal√≠a\n")

        for category, cat_data in results.get('category_results', {}).items():
            report.append(f"### {category.upper()}\n")
            report.append("| Tipo de Anomal√≠a | Im√°genes | IoU | Dice | F1 | AU-PRO |")
            report.append("|------------------|----------|-----|------|----|--------|")

            for anomaly_type, anom_data in cat_data.get('anomaly_results', {}).items():
                summary = anom_data.get('summary', {})
                cells = " | ".join(
                    f"{summary.get(col, 0):.4f}" for col in self._ANOMALY_COLUMNS
                )
                report.append(
                    f"| {anomaly_type} | {summary.get('n_images', 0)} | {cells} |"
                )
            report.append("")

        report_content = "\n".join(report)

        # Save when a path was given
        if output_path:
            os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(report_content)
            print(f"\nüìù Reporte guardado en: {output_path}")

        return report_content


class MVTecDatasetEvaluator:
    """
    End-to-end evaluator for the MVTec AD dataset.

    Iterates over the requested categories and their anomaly types, builds
    one memory bank per category from the images under ``train/good``, and
    aggregates metrics per anomaly type, per category and globally.

    Args:
        dataset_path: Base path of the MVTec AD dataset (the directory that
            contains the category folders)
        model_path: Path to the DINOv2 model
        layer_idx: Index of the layer to use (-1 = last)
        n_good_images: Number of 'good' images for the memory bank (None = all)
        k: Number of neighbours for k-NN
        coreset_ratio: Subsampling ratio of the memory bank
        threshold: Binarization threshold
        auto_normalize: If True, anomaly maps are normalized automatically
        device: Device string forwarded unchanged to the feature extractor
    """

    # The 15 MVTec AD categories.
    CATEGORIES = [
        'bottle', 'cable', 'capsule', 'carpet', 'grid',
        'hazelnut', 'leather', 'metal_nut', 'pill', 'screw',
        'tile', 'toothbrush', 'transistor', 'wood', 'zipper'
    ]

    # Metrics averaged by _compute_summary and listed in the global table.
    _SUMMARY_METRICS = ('IoU', 'Dice', 'Precision', 'Recall', 'F1', 'PRO', 'AU-PRO')
    # Column order of the per-category tables (console and Markdown).
    _CATEGORY_COLUMNS = ('IoU', 'Dice', 'F1', 'Precision', 'Recall', 'AU-PRO')
    # Column order of the per-anomaly-type Markdown tables.
    _ANOMALY_COLUMNS = ('IoU', 'Dice', 'F1', 'AU-PRO')

    def __init__(
        self,
        dataset_path: str,
        model_path: str,
        layer_idx: int = -1,
        n_good_images: Optional[int] = None,
        k: int = 1,
        coreset_ratio: float = 1.0,
        threshold: float = 0.5,
        auto_normalize: bool = True,
        device: Optional[str] = None
    ):
        self.dataset_path = dataset_path
        self.model_path = model_path
        self.layer_idx = layer_idx
        self.n_good_images = n_good_images
        self.k = k
        self.coreset_ratio = coreset_ratio
        self.threshold = threshold
        self.auto_normalize = auto_normalize
        self.device = device

        # Feature extractor shared by every per-category memory bank.
        self.extractor = DINOv2FeatureExtractor(
            model_path=model_path,
            layer_idx=layer_idx,
            device=device
        )

        # Metric evaluator applied to every (anomaly map, ground truth) pair.
        self.evaluator = AnomalyEvaluator(
            threshold=threshold,
            auto_normalize=auto_normalize
        )

        # Discover which of the known categories exist on disk.
        self.available_categories = self._discover_categories()
        print(f"üîç Encontradas {len(self.available_categories)} categor√≠as en {dataset_path}")

    def _discover_categories(self) -> List[str]:
        """Return the entries of CATEGORIES that exist as directories in the dataset."""
        import os
        return [
            cat for cat in self.CATEGORIES
            if os.path.isdir(os.path.join(self.dataset_path, cat))
        ]

    def _get_anomaly_types(self, category: str) -> List[str]:
        """
        Return the sorted anomaly types of a category, excluding 'good'.

        A type is listed only when it is a directory under ``test/`` and a
        directory of the same name exists under ``ground_truth/``.
        """
        import os
        test_path = os.path.join(self.dataset_path, category, 'test')
        gt_path = os.path.join(self.dataset_path, category, 'ground_truth')

        if not os.path.isdir(test_path):
            return []

        return [
            folder for folder in sorted(os.listdir(test_path))
            if folder != 'good'
            and os.path.isdir(os.path.join(test_path, folder))
            and os.path.isdir(os.path.join(gt_path, folder))
        ]

    def _load_good_images(self, category: str) -> List["Image.Image"]:
        """Load the 'good' training images of a category (at most n_good_images)."""
        import os
        good_path = os.path.join(self.dataset_path, category, 'train', 'good')
        return load_images_from_folder(good_path, n_images=self.n_good_images)

    def _build_memory_bank(self, category: str, verbose: bool = True) -> "MemoryBankDetector":
        """Build a MemoryBankDetector from the 'good' images of a category."""
        if verbose:
            print(f"\nüì¶ Construyendo Memory Bank para '{category}'...")

        good_images = self._load_good_images(category)

        detector = MemoryBankDetector(
            extractor=self.extractor,
            k=self.k,
            coreset_ratio=self.coreset_ratio
        )
        detector.build_memory_bank(good_images, verbose=verbose)

        return detector

    def evaluate_category(
        self,
        category: str,
        detector: Optional["MemoryBankDetector"] = None,
        verbose: bool = True
    ) -> dict:
        """
        Evaluate every anomaly type of one category.

        Args:
            category: Category name
            detector: Memory bank detector (if None, a new one is built)
            verbose: If True, print progress

        Returns:
            dict with keys 'category', 'anomaly_results' (per-type metrics and
            summary), 'all_metrics' (flat list of per-image metrics) and
            'summary' (category averages; empty when nothing was evaluated).
        """
        import os

        if detector is None:
            detector = self._build_memory_bank(category, verbose)

        anomaly_types = self._get_anomaly_types(category)

        if verbose:
            print(f"\nüîç Evaluando {len(anomaly_types)} tipos de anomal√≠a en '{category}'")

        category_results = {
            'category': category,
            'anomaly_results': {},
            'all_metrics': [],
            'summary': {}
        }

        for anomaly_type in anomaly_types:
            test_folder = os.path.join(self.dataset_path, category, 'test', anomaly_type)
            gt_folder = os.path.join(self.dataset_path, category, 'ground_truth', anomaly_type)

            if not os.path.isdir(gt_folder):
                if verbose:
                    print(f"  ‚ö†Ô∏è Sin ground truth para {anomaly_type}, saltando...")
                continue

            # Load (test image, ground-truth mask, filename) triples; no image cap.
            test_data = load_test_with_ground_truth(
                test_folder=test_folder,
                gt_folder=gt_folder,
                n_images=None
            )

            if len(test_data) == 0:
                if verbose:
                    print(f"  ‚ö†Ô∏è No hay datos para {anomaly_type}")
                continue

            if verbose:
                print(f"\n  üìç Tipo: {anomaly_type} ({len(test_data)} im√°genes)")

            anomaly_metrics = []

            for test_img, gt_mask, filename in test_data:
                # Only the second returned map and the image score are used here.
                _raw_map, amap_smooth, score = detector.compute_anomaly_map(test_img)

                metrics = self.evaluator.evaluate(amap_smooth, gt_mask)
                metrics['filename'] = filename
                metrics['category'] = category
                metrics['anomaly_type'] = anomaly_type
                metrics['image_score'] = score

                anomaly_metrics.append(metrics)

            # Averages for this anomaly type.
            anomaly_summary = self._compute_summary(anomaly_metrics)
            anomaly_summary['n_images'] = len(anomaly_metrics)

            category_results['anomaly_results'][anomaly_type] = {
                'metrics': anomaly_metrics,
                'summary': anomaly_summary
            }
            category_results['all_metrics'].extend(anomaly_metrics)

            if verbose:
                # .get() keeps the progress line from crashing when the
                # evaluator did not report one of the metrics.
                print(f"     IoU: {anomaly_summary.get('IoU', 0):.4f} | "
                      f"Dice: {anomaly_summary.get('Dice', 0):.4f} | "
                      f"F1: {anomaly_summary.get('F1', 0):.4f} | "
                      f"AU-PRO: {anomaly_summary.get('AU-PRO', 0):.4f}")

        # Summary over every image of the category.
        if category_results['all_metrics']:
            category_results['summary'] = self._compute_summary(category_results['all_metrics'])
            category_results['summary']['n_images'] = len(category_results['all_metrics'])
            # Count only the types that were actually evaluated, not the
            # ones skipped above for lack of data.
            category_results['summary']['n_anomaly_types'] = len(category_results['anomaly_results'])

        return category_results

    def _compute_summary(self, metrics_list: List[dict]) -> dict:
        """
        Average the known metrics over a list of per-image metric dicts.

        For each metric present in at least one dict the result holds
        ``<metric>`` (mean) and ``<metric>_std`` (standard deviation) as
        plain Python floats. An empty input yields an empty dict.
        """
        if not metrics_list:
            return {}

        summary = {}
        for key in self._SUMMARY_METRICS:
            values = [m[key] for m in metrics_list if key in m]
            if values:
                summary[key] = float(np.mean(values))
                summary[f'{key}_std'] = float(np.std(values))

        return summary

    def evaluate_all(
        self,
        categories: Optional[List[str]] = None,
        verbose: bool = True,
        save_results: bool = True,
        output_path: Optional[str] = None
    ) -> dict:
        """
        Evaluate all categories (or only the given ones).

        Args:
            categories: Categories to evaluate (None = every available one)
            verbose: If True, print detailed progress
            save_results: If True, save the results as JSON
            output_path: Where to save the results (if save_results=True);
                defaults to a timestamped file next to the dataset folder

        Returns:
            dict with the full results: per category, per anomaly type and global
        """
        import os
        import traceback
        from datetime import datetime

        if categories is None:
            categories = self.available_categories

        print("\n" + "=" * 80)
        print("üöÄ EVALUACI√ìN COMPLETA DEL DATASET MVTEC AD")
        print("=" * 80)
        print(f"   Modelo: {self.model_path}")
        print(f"   Capa: {self.layer_idx}")
        print(f"   k-NN: k={self.k}")
        print(f"   Umbral: {self.threshold}")
        print(f"   Categor√≠as: {len(categories)}")
        print("=" * 80)

        all_results = {
            'config': {
                'model_path': self.model_path,
                'layer_idx': self.layer_idx,
                'k': self.k,
                'coreset_ratio': self.coreset_ratio,
                'threshold': self.threshold,
                'auto_normalize': self.auto_normalize,
                'n_good_images': self.n_good_images,
                'timestamp': datetime.now().isoformat()
            },
            'category_results': {},
            'global_metrics': [],
            'summary_by_category': {},
            'summary_by_anomaly_type': {},
            'global_summary': {}
        }

        for i, category in enumerate(categories, 1):
            print("\n" + "=" * 80)
            print(f"üìÇ [{i}/{len(categories)}] Categor√≠a: {category.upper()}")
            print("=" * 80)

            # Best effort: a failing category is reported and skipped so the
            # remaining categories are still evaluated.
            try:
                cat_results = self.evaluate_category(category, verbose=verbose)
                all_results['category_results'][category] = cat_results
                all_results['global_metrics'].extend(cat_results['all_metrics'])
                all_results['summary_by_category'][category] = cat_results['summary']

                # Expose per-type summaries under a "category/type" key.
                for anomaly_type, anom_data in cat_results['anomaly_results'].items():
                    key = f"{category}/{anomaly_type}"
                    all_results['summary_by_anomaly_type'][key] = anom_data['summary']

            except Exception as e:
                print(f"  ‚ùå Error procesando {category}: {e}")
                traceback.print_exc()

        # Global metrics over every evaluated image.
        if all_results['global_metrics']:
            all_results['global_summary'] = self._compute_summary(all_results['global_metrics'])
            all_results['global_summary']['n_total_images'] = len(all_results['global_metrics'])
            # Categories that raised are not in 'category_results', so they
            # are not counted as evaluated.
            all_results['global_summary']['n_categories'] = len(all_results['category_results'])

        self._print_final_summary(all_results)

        if save_results:
            if output_path is None:
                stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                output_path = os.path.join(
                    self.dataset_path, '..',
                    f'evaluation_results_{stamp}.json'
                )
            self._save_results(all_results, output_path)

        return all_results

    def _print_final_summary(self, results: dict):
        """Print the per-category table and the global metrics to stdout."""
        print("\n" + "=" * 80)
        print("üìä RESUMEN FINAL DE EVALUACI√ìN")
        print("=" * 80)

        # Per-category table.
        print("\nüè∑Ô∏è  M√âTRICAS POR CATEGOR√çA:")
        print("-" * 80)
        print(f"{'Categor√≠a':<15} {'IoU':<10} {'Dice':<10} {'F1':<10} {'Precision':<10} {'Recall':<10} {'AU-PRO':<10}")
        print("-" * 80)

        for category, summary in results['summary_by_category'].items():
            if summary:
                cells = [f"{summary.get(col, 0):<10.4f}" for col in self._CATEGORY_COLUMNS]
                print(" ".join([f"{category:<15}"] + cells))

        # Global summary.
        print("\n" + "=" * 80)
        print("üåç M√âTRICAS GLOBALES:")
        print("=" * 80)
        gs = results['global_summary']
        if gs:
            print(f"   Total im√°genes evaluadas: {gs.get('n_total_images', 0)}")
            print(f"   Total categor√≠as: {gs.get('n_categories', 0)}")
            print()
            print("   üìà M√©tricas Pixel-Level:")
            for name in ('IoU', 'Dice', 'Precision', 'Recall', 'F1'):
                # Label plus colon is padded to 11 columns so values line up.
                print(f"      {name + ':':<11}{gs.get(name, 0):.4f} (¬± {gs.get(name + '_std', 0):.4f})")
            print()
            print("   üìà M√©tricas Region-Level:")
            for name in ('PRO', 'AU-PRO'):
                print(f"      {name + ':':<11}{gs.get(name, 0):.4f} (¬± {gs.get(name + '_std', 0):.4f})")

    def _save_results(self, results: dict, output_path: str):
        """Write the results to ``output_path`` as UTF-8 JSON, creating parent folders."""
        import json
        import os

        # Convert NumPy values and drop private keys before serializing.
        results_clean = self._clean_results_for_json(results)

        os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(results_clean, f, indent=2, ensure_ascii=False)

        print(f"\nüíæ Resultados guardados en: {output_path}")

    def _clean_results_for_json(self, results: dict) -> dict:
        """
        Return a JSON-serializable copy of ``results``.

        NumPy booleans, integers and floats become ``bool``, ``int`` and
        ``float``; arrays and tuples become lists; dict entries whose string
        key starts with '_' are dropped. Other values are returned as-is.
        """
        def clean_value(v):
            if isinstance(v, np.bool_):
                return bool(v)
            if isinstance(v, np.integer):
                # Keep counts integral instead of turning them into floats.
                return int(v)
            if isinstance(v, np.floating):
                return float(v)
            if isinstance(v, np.ndarray):
                return v.tolist()
            if isinstance(v, dict):
                return {
                    k: clean_value(val) for k, val in v.items()
                    if not (isinstance(k, str) and k.startswith('_'))
                }
            if isinstance(v, (list, tuple)):
                return [clean_value(item) for item in v]
            return v

        return clean_value(results)

    def generate_report(
        self,
        results: dict,
        output_path: Optional[str] = None
    ) -> str:
        """
        Build a detailed Markdown report.

        Args:
            results: Results returned by evaluate_all()
            output_path: Where to save the report (optional)

        Returns:
            The report content as Markdown
        """
        from datetime import datetime
        import os

        report = []
        report.append("# üìä Reporte de Evaluaci√≥n MVTec AD")
        report.append(f"\n**Fecha:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")

        # Configuration.
        config = results.get('config', {})
        # evaluate_all() always stores 'n_good_images', with None meaning
        # "all images", so the None value itself must map to the label.
        n_good = config.get('n_good_images')
        n_good_label = 'Todas' if n_good is None else n_good
        report.append("## ‚öôÔ∏è Configuraci√≥n\n")
        report.append(f"- **Modelo:** `{config.get('model_path', 'N/A')}`")
        report.append(f"- **Capa DINOv2:** {config.get('layer_idx', 'N/A')}")
        report.append(f"- **k-NN (k):** {config.get('k', 'N/A')}")
        report.append(f"- **Umbral:** {config.get('threshold', 'N/A')}")
        report.append(f"- **Im√°genes 'good' para Memory Bank:** {n_good_label}")
        report.append("")

        # Global summary.
        gs = results.get('global_summary', {})
        report.append("## üåç M√©tricas Globales\n")
        report.append(f"- **Total im√°genes evaluadas:** {gs.get('n_total_images', 0)}")
        report.append(f"- **Total categor√≠as:** {gs.get('n_categories', 0)}")
        report.append("")
        report.append("| M√©trica | Valor | Desv. Est. |")
        report.append("|---------|-------|------------|")
        for metric in self._SUMMARY_METRICS:
            val = gs.get(metric, 0)
            std = gs.get(f'{metric}_std', 0)
            report.append(f"| {metric} | {val:.4f} | ¬± {std:.4f} |")
        report.append("")

        # Per-category table. The delimiter row must contain only pipes and
        # dashes, otherwise Markdown does not render the table.
        report.append("## üè∑Ô∏è M√©tricas por Categor√≠a\n")
        report.append("| Categor√≠a | IoU | Dice | F1 | Precision | Recall | AU-PRO |")
        report.append("|-----------|-----|------|----|-----------|--------|--------|")

        for category, summary in results.get('summary_by_category', {}).items():
            if summary:
                cells = [f"{summary.get(col, 0):.4f}" for col in self._CATEGORY_COLUMNS]
                report.append("| " + " | ".join([category] + cells) + " |")
        report.append("")

        # Per-anomaly-type detail, one table per category.
        report.append("## üî¨ Detalle por Tipo de Anomal√≠a\n")

        for category, cat_data in results.get('category_results', {}).items():
            report.append(f"### {category.upper()}\n")
            report.append("| Tipo de Anomal√≠a | Im√°genes | IoU | Dice | F1 | AU-PRO |")
            report.append("|------------------|----------|-----|------|----|--------|")

            for anomaly_type, anom_data in cat_data.get('anomaly_results', {}).items():
                summary = anom_data.get('summary', {})
                cells = [f"{summary.get(col, 0):.4f}" for col in self._ANOMALY_COLUMNS]
                report.append(
                    "| " + " | ".join([anomaly_type, str(summary.get('n_images', 0))] + cells) + " |"
                )
            report.append("")

        report_content = "\n".join(report)

        # Save only when a path is given.
        if output_path:
            os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(report_content)
            print(f"\nüìù Reporte guardado en: {output_path}")

        return report_content


    
   
    # Banner announcing the full-dataset run (three lines, same as three prints).
    print("=" * 80, "üöÄ EVALUACI√ìN COMPLETA DEL DATASET MVTEC AD", "=" * 80, sep="\n")

    # Build the evaluator: last layer, 200 'good' images per memory bank,
    # 1-NN scoring and a binarization threshold of 0.6.
    evaluator_mvtec = MVTecDatasetEvaluator(
        "/home/bllancao/Portafolio/mvtec_anomaly_detection/data/raw",
        "/home/bllancao/Portafolio/mvtec_anomaly_detection/models/dinov2-base",
        layer_idx=-1,
        n_good_images=200,
        k=1,
        threshold=0.6,
    )

    # Evaluate the listed categories without writing a JSON file.
    results = evaluator_mvtec.evaluate_all(
        categories=[
            "bottle", "cable", "capsule", "carpet", "grid",
            "hazelnut", "leather", "metal_nut", "pill", "screw",
            "tile", "toothbrush", "transistor", "wood", "zipper",
        ],
        verbose=True,
        save_results=False,
    )

    print("\n‚úÖ Evaluaci√≥n completa finalizada")

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


üöÄ EVALUACI√ìN COMPLETA DEL DATASET MVTEC AD
üîç Encontradas 15 categor√≠as en /home/bllancao/Portafolio/mvtec_anomaly_detection/data/raw

üöÄ EVALUACI√ìN COMPLETA DEL DATASET MVTEC AD
   Modelo: /home/bllancao/Portafolio/mvtec_anomaly_detection/models/dinov2-base
   Capa: -1
   k-NN: k=1
   Umbral: 0.6
   Categor√≠as: 15

üìÇ [1/15] Categor√≠a: BOTTLE

üì¶ Construyendo Memory Bank para 'bottle'...
Cargadas 200 im√°genes de /home/bllancao/Portafolio/mvtec_anomaly_detection/data/raw/bottle/train/good
Procesadas 10/200 im√°genes
Procesadas 20/200 im√°genes
Procesadas 30/200 im√°genes
Procesadas 40/200 im√°genes
Procesadas 50/200 im√°genes
Procesadas 60/200 im√°genes
Procesadas 70/200 im√°genes
Procesadas 80/200 im√°genes
Procesadas 90/200 im√°genes
Procesadas 100/200 im√°genes
Procesadas 110/200 im√°genes
Procesadas 120/200 im√°genes
Procesadas 130/200 im√°genes
Procesadas 140/200 im√°genes
Procesadas 150/200 im√°genes
Procesadas 160/200 im√°genes
Procesadas 170/200 im√°genes
Proce

  au_pro = np.trapz(pro_sorted, fpr_sorted) / fpr_limit


     IoU: 0.6007 | Dice: 0.7490 | F1: 0.7490 | AU-PRO: 0.8788
Cargados 22 pares (test, ground_truth) de /home/bllancao/Portafolio/mvtec_anomaly_detection/data/raw/bottle/test/broken_small

  üìç Tipo: broken_small (22 im√°genes)
     IoU: 0.5307 | Dice: 0.6900 | F1: 0.6900 | AU-PRO: 0.8707
Cargados 21 pares (test, ground_truth) de /home/bllancao/Portafolio/mvtec_anomaly_detection/data/raw/bottle/test/contamination

  üìç Tipo: contamination (21 im√°genes)
     IoU: 0.5631 | Dice: 0.6936 | F1: 0.6936 | AU-PRO: 0.9250

üìÇ [2/15] Categor√≠a: CABLE

üì¶ Construyendo Memory Bank para 'cable'...
Cargadas 200 im√°genes de /home/bllancao/Portafolio/mvtec_anomaly_detection/data/raw/cable/train/good
Procesadas 10/200 im√°genes
Procesadas 20/200 im√°genes
Procesadas 30/200 im√°genes
Procesadas 40/200 im√°genes
Procesadas 50/200 im√°genes
Procesadas 60/200 im√°genes
Procesadas 70/200 im√°genes
Procesadas 80/200 im√°genes
Procesadas 90/200 im√°genes
Procesadas 100/200 im√°genes
Procesadas 110/