<a href="https://colab.research.google.com/github/jalevano/tfm_uoc_datascience/blob/main/01_Comparativa_HuggingFace.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
"""
Evaluador de Segmentación de Personas - Arquitectura SOLID
==========================================================

Sistema de evaluación siguiendo principios SOLID:
- Single Responsibility: Cada clase tiene una responsabilidad específica
- Open/Closed: Extensible para nuevos modelos/métricas
- Liskov Substitution: Interfaces bien definidas
- Interface Segregation: Interfaces específicas
- Dependency Inversion: Dependencias abstraídas

Referencias:
- Kirillov et al. "Panoptic Segmentation" CVPR 2019
- Cheng et al. "Per-Pixel Classification is Not All You Need" NeurIPS 2021
- Lin et al. "Microsoft COCO: Common Objects in Context" ECCV 2014

Autor: Jesús L.
Fecha: Agosto 2025.
"""

'\nEVALUACIÓN COMPARATIVA. Mask2former vs OneFormer para Segmentación de personas.\n\nSe proporcionar análisis detallado de las ventajas de la segmentación panóptica\ncon OneFormer frente segmentación de instancias tradicional.\n\nConceptos clave:\n- Segmentación panóptica: Segmentación semántica e instancias\n- Comparación arquitectónica: Transformer-based models estado del arte\n- Evaluación académica: Métricas comprehensivas.\n\nAutor: Jesús L.\nTrabajo: Evaluación comparativa de técnicas de segmentación.\nFecha: Agosto 2025.\n\nReferencias Técnicas:\n- Cheng et al. "Masked-attention Mask Transformer for Universal Image Segmentation" (Mask2Former)\n- Jain et al. "OneFormer: One Transformer to Rule Universal Image Segmentation" (OneFormer)\n- Kirillov et al. "Panoptic Segmentation" (Conceptos fundamentales)\n'

In [2]:
!pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-1.8.1-py3-none-any.whl.metadata (22 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Downloading torchmetrics-1.8.1-py3-none-any.whl (982 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.0/983.0 kB[0m [31m41.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.15.2-py3-none-any.whl (29 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.15.2 torchmetrics-1.8.1


In [15]:
import os
import json
import torch
import numpy as np
from abc import ABC, abstractmethod
from dataclasses import dataclass, asdict
from datetime import datetime
from typing import Dict, List, Optional, Protocol, Union, Tuple
from pathlib import Path
import logging
from PIL import Image
import cv2
from tqdm.auto import tqdm

# Transformers imports
from transformers import (
    AutoImageProcessor,
    AutoModelForUniversalSegmentation,
    Mask2FormerImageProcessor,
    Mask2FormerForUniversalSegmentation,
    OneFormerImageProcessor,
    OneFormerForUniversalSegmentation
)

# Metrics imports
from sklearn.metrics import precision_recall_fscore_support
import warnings
warnings.filterwarnings('ignore')

In [14]:
# Metrics imports
from sklearn.metrics import precision_recall_fscore_support
import warnings
warnings.filterwarnings('ignore')

In [16]:
# ============================================================================
# PROTOCOLS Y INTERFACES
# ============================================================================

class ModelInterface(Protocol):
    """Structural interface (typing.Protocol) for segmentation models.

    Any object exposing these two methods with compatible signatures
    satisfies the protocol; no inheritance is required.
    """

    def predict(self, image: Image.Image) -> Dict:
        """Run a prediction on a single PIL image and return the outputs as a dict."""
        ...

    def get_model_info(self) -> Dict:
        """Return information about the model as a dict."""
        ...


class MetricsCalculatorInterface(Protocol):
    """Structural interface (typing.Protocol) for metric calculators."""

    def calculate_metrics(self, prediction: np.ndarray, ground_truth: np.ndarray) -> Dict[str, float]:
        """Compute evaluation metrics for a prediction against its ground truth.

        Returns a mapping from metric name to value.
        """
        ...


class DataLoaderInterface(Protocol):
    """Structural interface (typing.Protocol) for data loaders."""

    def load_images(self) -> List[str]:
        """Return the list of image paths to evaluate."""
        ...

    def load_ground_truth(self, image_path: str) -> Optional[np.ndarray]:
        """Return the ground-truth array for the given image path, or None."""
        ...


class ResultsStorageInterface(Protocol):
    """Structural interface (typing.Protocol) for result storage backends."""

    def save_results(self, results: Dict, filepath: str) -> str:
        """Persist ``results`` at ``filepath`` and return a string.

        NOTE(review): the returned string is presumably the path that was
        written — confirm against the concrete implementations.
        """
        ...


In [17]:
# ============================================================================
# DATACLASSES Y CONFIGURACIONES
# ============================================================================

@dataclass(frozen=True)
class EvaluationConfig:
    """Immutable configuration for a segmentation evaluation run.

    The instance is frozen, so fields cannot be reassigned after creation.
    Values are validated in ``__post_init__``.

    Raises:
        ValueError: if ``confidence_threshold`` is outside [0.0, 1.0] or
            ``max_images`` is not a positive number.
        FileNotFoundError: if ``images_directory`` is not an existing directory.
    """
    # Model configuration
    model_id: str
    model_name: str
    confidence_threshold: float = 0.7

    # Data configuration
    images_directory: str = ""
    ground_truth_directory: str = ""
    output_directory: str = "evaluation_results"

    # Processing configuration
    target_size: Tuple[int, int] = (512, 512)
    max_images: Optional[int] = None
    device: str = "auto"

    # Evaluation configuration
    random_seed: int = 42
    evaluation_protocol: str = "COCO"
    save_predictions: bool = True
    calculate_panoptic_metrics: bool = True

    def __post_init__(self):
        """Validate field values right after the dataclass is initialised."""
        if not (0.0 <= self.confidence_threshold <= 1.0):
            raise ValueError(f"confidence_threshold debe estar en [0.0, 1.0], recibido: {self.confidence_threshold}")

        if self.max_images is not None and self.max_images <= 0:
            raise ValueError(f"max_images debe ser positivo, recibido: {self.max_images}")

        # isdir (not exists): a path pointing at a regular file would pass an
        # existence check and only fail later when the directory is listed.
        if not os.path.isdir(self.images_directory):
            raise FileNotFoundError(f"Directorio de imágenes no encontrado: {self.images_directory}")

In [18]:
@dataclass
class ImageResult:
    """Evaluation result for a single image."""
    image_name: str
    image_path: str
    success: bool
    processing_time_seconds: float
    # NOTE(review): the axis order of the two size tuples (width/height vs
    # height/width) is not fixed by this declaration — confirm with the producer.
    original_size: Tuple[int, int]
    processed_size: Tuple[int, int]
    has_ground_truth: bool
    has_person_prediction: bool
    # Metric name -> value mapping.
    metrics: Dict[str, float]
    prediction_metadata: Dict
    # Defaults to an empty string.
    error_message: str = ""


@dataclass
class ModelMetadata:
    """Metadata describing the evaluated model and the runtime it ran on."""
    model_id: str
    model_name: str
    architecture_type: str
    supports_panoptic: bool
    parameter_count: int
    device_used: str
    pytorch_version: str
    transformers_version: str


@dataclass
class EvaluationMetadata:
    """Full metadata of an evaluation run.

    Bundles the model metadata and the evaluation configuration with two
    free-form dictionaries (execution info and dataset statistics).
    """
    model_metadata: ModelMetadata
    evaluation_config: EvaluationConfig
    execution_info: Dict
    dataset_statistics: Dict

In [None]:
# ============================================================================
# CALCULADORA DE MÉTRICAS COMPREHENSIVAS
# ============================================================================

class ComprehensiveMetricsCalculator:
    """
    Calculadora de métricas comprehensivas para segmentación de personas.

    Implementa métricas estándar:
    - Métricas básicas: IoU, Dice, Precision, Recall, F1
    - Métricas de contorno: Boundary IoU, distancias
    - Métricas de área y conectividad
    """

    def __init__(self, person_class_id: int = 1):
        """
        Inicializar calculadora de métricas.

        Args:
            person_class_id: ID de clase para 'person' según dataset (COCO=1)
        """
        self.person_class_id = person_class_id
        self.epsilon = 1e-7  # Para estabilidad numérica

    def calculate_comprehensive_metrics(self, prediction: np.ndarray,
                                      ground_truth: np.ndarray) -> Dict[str, float]:
        """
        Calcular conjunto comprehensivo de métricas.

        Args:
            prediction: Máscara predicha binaria (0, 1)
            ground_truth: Máscara de ground truth binaria (0, 1)

        Returns:
            Dict con métricas completas
        """
        try:
            # Asegurar compatibilidad de dimensiones
            prediction, ground_truth = self._ensure_compatible_shapes(prediction, ground_truth)

            # Calcular métricas por categoría
            basic_metrics = self._calculate_basic_metrics(prediction, ground_truth)
            boundary_metrics = self._calculate_boundary_metrics(prediction, ground_truth)
            area_metrics = self._calculate_area_metrics(prediction, ground_truth)
            connectivity_metrics = self._calculate_connectivity_metrics(prediction, ground_truth)

            # Combinar todas las métricas
            comprehensive_metrics = {
                **basic_metrics,
                **boundary_metrics,
                **area_metrics,
                **connectivity_metrics
            }

            return comprehensive_metrics

        except Exception as e:
            logging.warning(f"Error calculando métricas: {str(e)}")
            return self._get_empty_metrics()

    def _ensure_compatible_shapes(self, pred: np.ndarray, gt: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Asegurar que predicción y ground truth tengan formas compatibles."""
        if pred.shape != gt.shape:
            pred = cv2.resize(pred.astype(np.uint8), (gt.shape[1], gt.shape[0]),
                            interpolation=cv2.INTER_NEAREST)

        # Asegurar valores binarios
        pred = (pred > 0.5).astype(np.uint8)
        gt = (gt > 0.5).astype(np.uint8)

        return pred, gt

    def _calculate_basic_metrics(self, pred: np.ndarray, gt: np.ndarray) -> Dict[str, float]:
        """Calcular métricas básicas de segmentación."""
        pred_flat = pred.flatten()
        gt_flat = gt.flatten()

        # Componentes de matriz de confusión
        tp = np.sum((pred_flat == 1) & (gt_flat == 1))
        fp = np.sum((pred_flat == 1) & (gt_flat == 0))
        tn = np.sum((pred_flat == 0) & (gt_flat == 0))
        fn = np.sum((pred_flat == 0) & (gt_flat == 1))

        metrics = {}

        # Intersection over Union (Jaccard Index)
        intersection = tp
        union = tp + fp + fn
        metrics['iou'] = float(intersection / (union + self.epsilon))

        # Dice Coefficient
        metrics['dice'] = float(2 * tp / (2 * tp + fp + fn + self.epsilon))

        # Precision (Positive Predictive Value)
        metrics['precision'] = float(tp / (tp + fp + self.epsilon))

        # Recall (Sensitivity)
        metrics['recall'] = float(tp / (tp + fn + self.epsilon))

        # Specificity (True Negative Rate)
        metrics['specificity'] = float(tn / (tn + fp + self.epsilon))

        # F1-Score
        metrics['f1_score'] = float(2 * metrics['precision'] * metrics['recall'] /
                                   (metrics['precision'] + metrics['recall'] + self.epsilon))

        # Accuracy
        metrics['pixel_accuracy'] = float((tp + tn) / (tp + tn + fp + fn + self.epsilon))

        # Balanced Accuracy
        sensitivity = metrics['recall']
        specificity = metrics['specificity']
        metrics['balanced_accuracy'] = float((sensitivity + specificity) / 2)

        # Matthews Correlation Coefficient
        numerator = (tp * tn) - (fp * fn)
        denominator = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
        metrics['mcc'] = float(numerator / (denominator + self.epsilon))

        # Fowlkes-Mallows Index
        metrics['fowlkes_mallows'] = float(np.sqrt(metrics['precision'] * metrics['recall']))

        # Componentes de matriz de confusión
        metrics.update({
            'true_positives': int(tp),
            'false_positives': int(fp),
            'true_negatives': int(tn),
            'false_negatives': int(fn)
        })

        return metrics

    def _calculate_boundary_metrics(self, pred: np.ndarray, gt: np.ndarray) -> Dict[str, float]:
        """Calcular métricas de precisión de contorno."""
        try:
            pred_contours = self._extract_contours(pred)
            gt_contours = self._extract_contours(gt)

            metrics = {}

            if len(pred_contours) > 0 and len(gt_contours) > 0:
                # Boundary IoU aproximado
                pred_boundary = self._dilate_contours(pred, iterations=2)
                gt_boundary = self._dilate_contours(gt, iterations=2)

                boundary_intersection = np.sum((pred_boundary == 1) & (gt_boundary == 1))
                boundary_union = np.sum((pred_boundary == 1) | (gt_boundary == 1))

                metrics['boundary_iou'] = float(boundary_intersection / (boundary_union + self.epsilon))

                # Distancia promedio de contorno
                metrics['average_boundary_distance'] = self._calculate_average_boundary_distance(
                    pred_contours, gt_contours
                )
            else:
                metrics['boundary_iou'] = 0.0
                metrics['average_boundary_distance'] = float('inf')

            return metrics

        except Exception:
            return {'boundary_iou': 0.0, 'average_boundary_distance': float('inf')}

    def _calculate_area_metrics(self, pred: np.ndarray, gt: np.ndarray) -> Dict[str, float]:
        """Calcular métricas relacionadas con área."""
        pred_area = np.sum(pred)
        gt_area = np.sum(gt)
        total_pixels = pred.size

        metrics = {
            'predicted_area_ratio': float(pred_area / total_pixels),
            'ground_truth_area_ratio': float(gt_area / total_pixels),
            'area_ratio_error': float(abs(pred_area - gt_area) / (gt_area + self.epsilon)),
            'relative_area_error': float((pred_area - gt_area) / (gt_area + self.epsilon)),
            'area_correlation': float(np.corrcoef(pred.flatten(), gt.flatten())[0, 1]
                                   if gt_area > 0 else 0.0)
        }

        return metrics

    def _calculate_connectivity_metrics(self, pred: np.ndarray, gt: np.ndarray) -> Dict[str, float]:
        """Calcular métricas de conectividad y topología."""
        try:
            # Número de componentes conectados
            pred_components = cv2.connectedComponents(pred.astype(np.uint8))[0] - 1
            gt_components = cv2.connectedComponents(gt.astype(np.uint8))[0] - 1

            metrics = {
                'predicted_components': int(pred_components),
                'ground_truth_components': int(gt_components),
                'component_count_error': int(abs(pred_components - gt_components))
            }

            # Métrica de fragmentación
            if gt_components > 0:
                metrics['fragmentation_ratio'] = float(pred_components / gt_components)
            else:
                metrics['fragmentation_ratio'] = float('inf') if pred_components > 0 else 1.0

            return metrics

        except Exception:
            return {
                'predicted_components': 0,
                'ground_truth_components': 0,
                'component_count_error': 0,
                'fragmentation_ratio': 1.0
            }

    def _extract_contours(self, mask: np.ndarray) -> List:
        """Extraer contornos de máscara binaria."""
        contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        return contours

    def _dilate_contours(self, mask: np.ndarray, iterations: int = 1) -> np.ndarray:
        """Dilatar contornos para cálculo de boundary IoU."""
        kernel = np.ones((3, 3), np.uint8)
        return cv2.dilate(mask.astype(np.uint8), kernel, iterations=iterations)

    def _calculate_average_boundary_distance(self, pred_contours: List, gt_contours: List) -> float:
        """Calcular distancia promedio entre contornos."""
        try:
            if not pred_contours or not gt_contours:
                return float('inf')

            pred_points = pred_contours[0].reshape(-1, 2) if len(pred_contours) > 0 else np.array([[0, 0]])
            gt_points = gt_contours[0].reshape(-1, 2) if len(gt_contours) > 0 else np.array([[0, 0]])

            # Calcular distancia Hausdorff aproximada
            distances = []
            for pred_point in pred_points[::10]:  # Subsample para eficiencia
                min_dist = np.min(np.linalg.norm(gt_points - pred_point, axis=1))
                distances.append(min_dist)

            return float(np.mean(distances)) if distances else float('inf')

        except Exception:
            return float('inf')

    def _get_empty_metrics(self) -> Dict[str, float]:
        """Retornar métricas vacías en caso de error."""
        return {
            'iou': 0.0, 'dice': 0.0, 'precision': 0.0, 'recall': 0.0,
            'specificity': 0.0, 'f1_score': 0.0, 'pixel_accuracy': 0.0,
            'balanced_accuracy': 0.0, 'mcc': 0.0, 'fowlkes_mallows': 0.0,
            'boundary_iou': 0.0, 'average_boundary_distance': float('inf'),
            'predicted_area_ratio': 0.0, 'ground_truth_area_ratio': 0.0,
            'area_ratio_error': float('inf'), 'relative_area_error': float('inf'),
            'area_correlation': 0.0, 'predicted_components': 0,
            'ground_truth_components': 0, 'component_count_error': 0,
            'fragmentation_ratio': 1.0, 'true_positives': 0,
            'false_positives': 0, 'true_negatives': 0, 'false_negatives': 0
        }

In [19]:
# ============================================================================
# CARGADORES DE DATOS
# ============================================================================

class RobustDataLoader:
    """Cargador de datos con validación robusta."""

    def __init__(self, config: EvaluationConfig):
        """
        Inicializar cargador de datos.

        Args:
            config: Configuración de evaluación validada
        """
        self.config = config
        self.valid_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif'}
        self.logger = logging.getLogger(__name__)

    def load_images(self) -> List[str]:
        """
        Cargar lista validada de rutas de imágenes.

        Returns:
            Lista ordenada de rutas de imágenes válidas
        """
        self.logger.info(f"Cargando imágenes desde: {self.config.images_directory}")

        image_paths = []

        for filename in os.listdir(self.config.images_directory):
            if self._is_valid_image_file(filename):
                full_path = os.path.join(self.config.images_directory, filename)

                if self._validate_image_file(full_path):
                    image_paths.append(full_path)
                else:
                    self.logger.warning(f"Archivo de imagen inválido omitido: {filename}")

        if not image_paths:
            raise ValueError(f"No se encontraron imágenes válidas en {self.config.images_directory}")

        # Ordenar para reproducibilidad
        image_paths.sort()

        # Aplicar límite si especificado
        if self.config.max_images:
            image_paths = image_paths[:self.config.max_images]
            self.logger.info(f"Limitando evaluación a {self.config.max_images} imágenes")

        self.logger.info(f"Cargadas {len(image_paths)} imágenes válidas")
        return image_paths

    def load_ground_truth(self, image_path: str) -> Optional[np.ndarray]:
        """
        Cargar ground truth correspondiente a una imagen.

        Args:
            image_path: Ruta de la imagen original

        Returns:
            Array numpy con máscara binaria o None si no existe
        """
        if not self.config.ground_truth_directory:
            return None

        try:
            image_name = Path(image_path).stem
            gt_extensions = ['.png', '.jpg', '.jpeg', '.bmp']

            gt_path = None
            for ext in gt_extensions:
                candidate_path = os.path.join(self.config.ground_truth_directory, f"{image_name}{ext}")
                if os.path.exists(candidate_path):
                    gt_path = candidate_path
                    break

            if not gt_path:
                return None

            # Cargar y validar ground truth
            gt_image = Image.open(gt_path).convert('L')

            # Redimensionar si necesario
            if gt_image.size != self.config.target_size[::-1]:
                gt_image = gt_image.resize(self.config.target_size[::-1], Image.NEAREST)

            # Convertir a array binario
            gt_array = np.array(gt_image)
            binary_mask = (gt_array > 127).astype(np.uint8)

            if np.sum(binary_mask) == 0:
                self.logger.warning(f"Ground truth vacío para: {image_name}")

            return binary_mask

        except Exception as e:
            self.logger.warning(f"Error cargando ground truth para {image_path}: {str(e)}")
            return None

    def _is_valid_image_file(self, filename: str) -> bool:
        """Verificar si archivo tiene extensión de imagen válida."""
        return any(filename.lower().endswith(ext) for ext in self.valid_extensions)

    def _validate_image_file(self, filepath: str) -> bool:
        """Validar que archivo de imagen sea accesible y válido."""
        try:
            if not os.path.isfile(filepath):
                return False

            # Intentar abrir imagen para validar formato
            with Image.open(filepath) as img:
                img.verify()

            # Verificar tamaño mínimo
            with Image.open(filepath) as img:
                if img.size[0] < 32 or img.size[1] < 32:
                    self.logger.warning(f"Imagen demasiado pequeña: {filepath}")
                    return False

            return True

        except Exception:
            return False


In [None]:
# ============================================================================
# WRAPPERS DE MODELOS
# ============================================================================

class AbstractModelWrapper(ABC):
    """Abstract base class for segmentation model wrappers.

    Concrete wrappers must implement all three abstract methods before they
    can be instantiated.
    """

    @abstractmethod
    def predict(self, image: Image.Image) -> Dict:
        """Run a prediction on a PIL image and return the outputs as a dict."""
        pass

    @abstractmethod
    def get_model_info(self) -> ModelMetadata:
        """Return the metadata of the wrapped model."""
        pass

    @abstractmethod
    def extract_person_mask(self, prediction_outputs: Dict) -> Optional[np.ndarray]:
        """Extract the person mask from the model outputs, or return None."""
        pass


class Mask2FormerWrapper(AbstractModelWrapper):
    """Wrapper para Mask2Former."""

    def __init__(self, model_id: str, device: str):
        """Load a Mask2Former checkpoint and its image processor.

        Args:
            model_id: Checkpoint identifier passed to ``from_pretrained``.
            device: Device the model is moved to (passed to ``model.to``).
        """
        self.model_id = model_id
        self.device = device

        # Load the processor and the model; switch to eval mode for inference.
        self.processor = Mask2FormerImageProcessor.from_pretrained(model_id)
        self.model = Mask2FormerForUniversalSegmentation.from_pretrained(model_id)
        self.model.eval()
        self.model = self.model.to(device)

        # Resolve the 'person' class id from the checkpoint's own label map
        # instead of hard-coding it: Hugging Face checkpoints index labels
        # from 0, so the original COCO id (1) points at a different class.
        # Falls back to 1 when the label map has no 'person' entry.
        id2label = getattr(self.model.config, "id2label", None) or {}
        self.person_class_id = next(
            (int(idx) for idx, label in id2label.items()
             if str(label).strip().lower() == "person"),
            1,
        )

        # Metadata
        self.parameter_count = sum(p.numel() for p in self.model.parameters())

    def predict(self, image: Image.Image) -> Dict:
        """Run Mask2Former on one image and post-process the outputs.

        Args:
            image: Input PIL image.

        Returns:
            Dict with ``'semantic_segmentation'`` (numpy array) and, when
            panoptic post-processing succeeds, ``'panoptic_segmentation'``
            (dict with ``'segmentation'`` and ``'segments_info'``). An empty
            dict is returned if inference or semantic post-processing fails.
        """
        try:
            with torch.no_grad():
                inputs = self.processor(images=image, return_tensors="pt")
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                outputs = self.model(**inputs)

            # PIL reports size as (width, height); post-processing expects the reverse.
            target_sizes = [image.size[::-1]]
            processed_outputs = {}

            # Semantic segmentation
            semantic_map = self.processor.post_process_semantic_segmentation(
                outputs, target_sizes=target_sizes
            )[0]
            processed_outputs['semantic_segmentation'] = semantic_map.cpu().numpy()

            # Panoptic segmentation is optional: on failure keep the semantic
            # result, but log the reason instead of discarding it silently
            # (a bare except would also trap KeyboardInterrupt).
            try:
                panoptic_map = self.processor.post_process_panoptic_segmentation(
                    outputs, target_sizes=target_sizes
                )[0]
                processed_outputs['panoptic_segmentation'] = {
                    'segmentation': panoptic_map['segmentation'].cpu().numpy(),
                    'segments_info': panoptic_map['segments_info']
                }
            except Exception as e:
                logging.warning(f"Segmentación panóptica no disponible (Mask2Former): {str(e)}")

            return processed_outputs

        except Exception as e:
            logging.error(f"Error en predicción Mask2Former: {str(e)}")
            return {}

    def extract_person_mask(self, prediction_outputs: Dict) -> Optional[np.ndarray]:
        """Extraer máscara de persona desde outputs de Mask2Former."""
        try:
            # Priorizar segmentación panóptica
            if 'panoptic_segmentation' in prediction_outputs:
                panoptic_data = prediction_outputs['panoptic_segmentation']
                segmentation_map = panoptic_data['segmentation']
                segments_info = panoptic_data['segments_info']

                person_mask = np.zeros_like(segmentation_map, dtype=np.uint8)
                for segment in segments_info:
                    if segment.get('category_id') == self.person_class_id:
                        person_mask[segmentation_map == segment['id']] = 1

                return person_mask

            # Fallback a segmentación semántica
            elif 'semantic_segmentation' in prediction_outputs:
                semantic_map = prediction_outputs['semantic_segmentation']
                person_mask = (semantic_map == self.person_class_id).astype(np.uint8)
                return person_mask

            return None

        except Exception as e:
            logging.warning(f"Error extrayendo máscara de persona: {str(e)}")
            return None

    def get_model_info(self) -> ModelMetadata:
        """Return the metadata record describing this Mask2Former instance."""
        import transformers

        metadata_fields = dict(
            model_id=self.model_id,
            model_name="Mask2Former",
            architecture_type="mask2former",
            supports_panoptic=True,
            parameter_count=self.parameter_count,
            device_used=self.device,
            pytorch_version=torch.__version__,
            transformers_version=transformers.__version__,
        )
        return ModelMetadata(**metadata_fields)


class OneFormerWrapper(AbstractModelWrapper):
    """Wrapper para OneFormer."""

    def __init__(self, model_id: str, device: str):
        """Load a OneFormer checkpoint and its image processor.

        Args:
            model_id: Checkpoint identifier passed to ``from_pretrained``.
            device: Device the model is moved to (passed to ``model.to``).
        """
        self.model_id = model_id
        self.device = device

        # Load the processor and the model; switch to eval mode for inference.
        # NOTE(review): OneFormer is task-conditioned and its documented usage
        # goes through OneFormerProcessor (image processor + tokenizer) —
        # confirm that the bare image processor is sufficient for inference.
        self.processor = OneFormerImageProcessor.from_pretrained(model_id)
        self.model = OneFormerForUniversalSegmentation.from_pretrained(model_id)
        self.model.eval()
        self.model = self.model.to(device)

        # Resolve the 'person' class id from the checkpoint's own label map
        # instead of hard-coding it: Hugging Face checkpoints index labels
        # from 0, so the original COCO id (1) points at a different class.
        # Falls back to 1 when the label map has no 'person' entry.
        id2label = getattr(self.model.config, "id2label", None) or {}
        self.person_class_id = next(
            (int(idx) for idx, label in id2label.items()
             if str(label).strip().lower() == "person"),
            1,
        )

        # Metadata
        self.parameter_count = sum(p.numel() for p in self.model.parameters())

    def predict(self, image: Image.Image) -> Dict:
        """Run OneFormer on one image and post-process the outputs.

        OneFormer is task-conditioned: besides the pixels it needs a tokenised
        task prompt. The inputs are therefore built with a processor that
        carries a tokenizer (see ``_task_aware_processor``), using the
        "panoptic" task.

        Args:
            image: Input PIL image.

        Returns:
            Dict that may contain ``'semantic_segmentation'`` (numpy array)
            and ``'panoptic_segmentation'`` (dict with ``'segmentation'`` and
            ``'segments_info'``); each is included only if its post-processing
            succeeds. An empty dict is returned if inference fails.
        """
        try:
            task_processor = self._task_aware_processor()

            with torch.no_grad():
                inputs = task_processor(images=image, task_inputs=["panoptic"], return_tensors="pt")
                # Only tensors can be moved to the device; leave anything else untouched.
                inputs = {
                    k: (v.to(self.device) if isinstance(v, torch.Tensor) else v)
                    for k, v in inputs.items()
                }
                outputs = self.model(**inputs)

            # PIL reports size as (width, height); post-processing expects the reverse.
            target_sizes = [image.size[::-1]]
            processed_outputs = {}

            # Semantic segmentation (best effort, failure is logged)
            try:
                semantic_map = self.processor.post_process_semantic_segmentation(
                    outputs, target_sizes=target_sizes
                )[0]
                processed_outputs['semantic_segmentation'] = semantic_map.cpu().numpy()
            except Exception as e:
                logging.warning(f"Segmentación semántica no disponible (OneFormer): {str(e)}")

            # Panoptic segmentation (best effort, failure is logged)
            try:
                panoptic_map = self.processor.post_process_panoptic_segmentation(
                    outputs, target_sizes=target_sizes
                )[0]
                processed_outputs['panoptic_segmentation'] = {
                    'segmentation': panoptic_map['segmentation'].cpu().numpy(),
                    'segments_info': panoptic_map['segments_info']
                }
            except Exception as e:
                logging.warning(f"Segmentación panóptica no disponible (OneFormer): {str(e)}")

            return processed_outputs

        except Exception as e:
            logging.error(f"Error en predicción OneFormer: {str(e)}")
            return {}

    def _task_aware_processor(self):
        """Return a processor able to tokenise OneFormer task prompts.

        If ``self.processor`` already exposes a tokenizer it is used as is.
        Otherwise a ``OneFormerProcessor`` is loaded once for ``self.model_id``
        and cached on the instance.
        """
        if hasattr(self.processor, "tokenizer"):
            return self.processor

        cached = getattr(self, "_full_processor", None)
        if cached is None:
            # Local import: the file's import cell does not bring this name in.
            from transformers import OneFormerProcessor
            cached = OneFormerProcessor.from_pretrained(self.model_id)
            self._full_processor = cached
        return cached