In [None]:
from sklearn.metrics import roc_auc_score
from scipy.spatial.distance import directed_hausdorff

class SegmentationEvaluator:
    """Evaluate a binary segmentation model on a test loader.

    For every batch the model output is binarised with ``threshold`` and each
    metric in ``metrics_funcs`` is computed over the whole batch; the per-batch
    values are then averaged over the number of batches. Throughput/latency
    figures from ``fps_on_loader`` and ``measure_fps_repetitions`` are merged
    into the result, which is printed and logged with ``wandb.log``.
    """

    def __init__(self, model, test_loader, threshold=0.6, device=None):
        """
        Args:
            model: ``torch.nn.Module`` to evaluate; it is moved to the device.
            test_loader: iterable of ``(images, labels)`` batches supporting
                ``len()`` (used to average the per-batch metrics).
            threshold: cut-off applied to the model output to get binary masks.
            device: target device (string or ``torch.device``); defaults to
                CUDA when available, otherwise CPU.
        """
        # Resolve the device BEFORE moving the model, so the model and the
        # input batches (moved to self.device in evaluate) share a device.
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        self.model = model.to(self.device)
        self.test_loader = test_loader
        self.threshold = threshold

        # Metric name -> function(preds, labels) returning a float.
        self.metrics_funcs = {
            "Dice Coefficient": self.dice_coefficient,
            "IoU Score": self.iou_score,
            "Precision": self.precision,
            "Sensitivity (Recall)": self.recall,
            "Specificity": self.specificity,
            "Accuracy": self.accuracy,
            "F1-Score": self.f1_score,
            # "AUC": self.auc_score,  # disabled; may return NaN on single-class batches
            "Hausdorff": self.hausdorff_distance,
        }

    @staticmethod
    def dice_coefficient(preds, labels):
        """Dice = 2|P∩L| / (|P| + |L|) over all elements of the binary tensors."""
        preds = preds.float()
        labels = labels.float()
        intersection = (preds * labels).sum()
        union = preds.sum() + labels.sum()
        return (2.0 * intersection / (union + 1e-8)).item()

    @staticmethod
    def iou_score(preds, labels):
        """Intersection over union |P∩L| / |P∪L| of two binary tensors."""
        preds = preds.float()
        labels = labels.float()
        intersection = (preds * labels).sum()
        union = preds.sum() + labels.sum() - intersection
        return (intersection / (union + 1e-8)).item()

    @staticmethod
    def precision(preds, labels):
        """Precision TP / (TP + FP) of binary predictions."""
        preds = preds.float()
        labels = labels.float()
        tp = (preds * labels).sum()
        fp = (preds * (1 - labels)).sum()
        return (tp / (tp + fp + 1e-8)).item()

    @staticmethod
    def recall(preds, labels):
        """Sensitivity (recall) TP / (TP + FN) of binary predictions."""
        preds = preds.float()
        labels = labels.float()
        tp = (preds * labels).sum()
        fn = ((1 - preds) * labels).sum()
        return (tp / (tp + fn + 1e-8)).item()

    @staticmethod
    def specificity(preds, labels):
        """Specificity TN / (TN + FP) of binary predictions."""
        preds = preds.float()
        labels = labels.float()
        tn = ((1 - preds) * (1 - labels)).sum()
        fp = (preds * (1 - labels)).sum()
        return (tn / (tn + fp + 1e-8)).item()

    @staticmethod
    def accuracy(preds, labels):
        """Fraction of elements where the prediction equals the label."""
        preds = preds.float()
        labels = labels.float()
        correct = (preds == labels).sum()
        total = labels.numel()
        return (correct / total).item()

    @staticmethod
    def f1_score(preds, labels):
        """F1 score: harmonic mean of ``precision`` and ``recall``."""
        precision = SegmentationEvaluator.precision(preds, labels)
        recall = SegmentationEvaluator.recall(preds, labels)
        return 2 * precision * recall / (precision + recall + 1e-8)

    @staticmethod
    def auc_score(preds_prob, labels):
        """ROC AUC over all pixels of an image or batch.

        Args:
            preds_prob: continuous scores (probabilities or logits; AUC is
                rank-based, so any monotonic transform gives the same value).
            labels: binary ground truth.

        Returns:
            The AUC, or ``np.nan`` when the labels do not contain both classes.
        """
        preds = preds_prob.view(-1).cpu().numpy()
        labels = labels.view(-1).cpu().numpy()

        # AUC is undefined unless there is at least one positive and one negative pixel.
        if np.any(labels == 1) and np.any(labels == 0):
            return roc_auc_score(labels, preds)
        else:
            return np.nan  # skip batches that lack one of the classes

    @staticmethod
    def _hausdorff_single(pred, label):
        """Symmetric Hausdorff distance between two numpy masks of equal rank.

        Returns ``inf`` when either mask has no foreground pixel.
        """
        # Coordinates of the foreground (== 1) pixels.
        pred_points = np.argwhere(pred == 1)
        label_points = np.argwhere(label == 1)

        if len(pred_points) == 0 or len(label_points) == 0:
            return float('inf')  # undefined when either mask is empty

        d_pred_label = directed_hausdorff(pred_points, label_points)[0]
        d_label_pred = directed_hausdorff(label_points, pred_points)[0]
        return max(d_pred_label, d_label_pred)

    @staticmethod
    def hausdorff_distance(preds, labels):
        """Symmetric Hausdorff distance between binary masks (values 0/1).

        Accepts a single ``[H, W]`` mask pair, or batched tensors whose last two
        dimensions are spatial (e.g. ``[B, 1, H, W]``). Batched inputs are split
        into individual 2-D masks and the mean per-mask distance is returned,
        so pixels of different images are never compared with each other.
        A mask pair where either mask is empty contributes ``inf`` (so the
        batch mean is then ``inf`` as well).

        Raises:
            ValueError: if ``preds`` and ``labels`` hold a different number of
                2-D masks.
        """
        preds = preds.detach().cpu().numpy()
        labels = labels.detach().cpu().numpy()

        if preds.ndim <= 2:
            return SegmentationEvaluator._hausdorff_single(preds, labels)

        height, width = preds.shape[-2:]
        pred_masks = preds.reshape(-1, height, width)
        label_masks = labels.reshape(-1, height, width)
        if len(pred_masks) != len(label_masks):
            raise ValueError(
                f"preds has {len(pred_masks)} masks but labels has {len(label_masks)}"
            )

        distances = [
            SegmentationEvaluator._hausdorff_single(p, l)
            for p, l in zip(pred_masks, label_masks)
        ]
        return float(np.mean(distances))

    def evaluate(self, visualize=False, num_images=2, save_path=None):
        """Run the model over ``test_loader`` and return the averaged metrics.

        Args:
            visualize: if True, plot the first ``num_images`` input/prediction/
                ground-truth triplets via ``visualize_results``.
            num_images: number of samples to plot when ``visualize`` is True.
            save_path: optional file path where the plot is saved.

        Returns:
            Dict mapping each metric name to its mean over batches, plus the
            timing keys 'FPS', 'ms_per_img', 'avg_fps', 'std_fps',
            'avg_latency' and 'std_latency'. The dict is also printed and
            passed to ``wandb.log``.

        Raises:
            ValueError: if ``test_loader`` is empty.
        """
        self.model.eval()
        metrics = {key: 0 for key in self.metrics_funcs}
        num_batches = len(self.test_loader)
        if num_batches == 0:
            raise ValueError("test_loader is empty; nothing to evaluate")

        all_images = []
        all_preds = []
        all_labels = []
        print("Segmentation Evaluator/evaluate using threshold: ", self.threshold)
        with torch.no_grad():
            for images, labels in self.test_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.model(images)
                # NOTE(review): the threshold is applied to the raw model output,
                # not to torch.sigmoid(outputs). This is only right if the model
                # already ends with a sigmoid; if it returns logits, threshold
                # torch.sigmoid(outputs) instead — confirm.
                preds_bin = (outputs > self.threshold).float()

                for key, func in self.metrics_funcs.items():
                    if key == "AUC":
                        # AUC needs the continuous scores, not the binary masks.
                        metrics[key] += func(outputs, labels)
                    else:
                        metrics[key] += func(preds_bin, labels)

                if visualize:
                    all_images.append(images.cpu())
                    all_preds.append(preds_bin.cpu())
                    all_labels.append(labels.cpu())

        metrics = {k: v / num_batches for k, v in metrics.items()}

        # Pass the evaluator's device explicitly: both helpers default to
        # "cuda", which fails on CPU-only machines and would move the model
        # away from self.device.
        bench_device = str(self.device)
        fps, ms_per_img = fps_on_loader(self.model, self.test_loader, device=bench_device)
        avg_fps, std_fps, avg_latency, std_latency = measure_fps_repetitions(
            self.model, self.test_loader, device=bench_device
        )

        inference_time = {
            'FPS': fps,
            'ms_per_img': ms_per_img,
            'avg_fps': avg_fps,
            'std_fps': std_fps,
            'avg_latency': avg_latency,
            'std_latency': std_latency,
        }
        if visualize:
            self.visualize_results(all_images, all_preds, all_labels, num_images, save_path)

        metrics.update(inference_time)
        print(metrics)
        print(inference_time)
        wandb.log(metrics)
        return metrics

    @staticmethod
    def visualize_results(images, preds, labels, num_images=2, save_path=None):
        """Plot input / predicted mask / ground-truth rows for the first samples.

        Args:
            images, preds, labels: lists of per-batch CPU tensors (concatenated
                along dim 0).
            num_images: maximum number of rows to plot; clamped to the number
                of available samples.
            save_path: if given, the figure is also saved to this path
                (missing parent directories are created).
        """
        images = torch.cat(images)
        preds = torch.cat(preds)
        labels = torch.cat(labels)

        # Never request more rows than there are samples (avoids IndexError).
        num_images = min(num_images, len(images))
        images = images[:num_images]
        preds = preds[:num_images]
        labels = labels[:num_images]

        num_rows = num_images
        plt.figure(figsize=(10, 3 * num_rows))

        for i in range(num_images):
            image = images[i].squeeze(0).numpy()  # drop a singleton channel dim
            if image.ndim == 3:
                # Multi-channel (C, H, W) input: matplotlib expects (H, W, C).
                image = image.transpose(1, 2, 0)
            pred = preds[i].squeeze(0).numpy()
            label = labels[i].squeeze(0).numpy()

            plt.subplot(num_rows, 3, i * 3 + 1)
            plt.imshow(image, cmap="gray")
            plt.title("Input Image")
            plt.axis("off")

            plt.subplot(num_rows, 3, i * 3 + 2)
            plt.imshow(pred, cmap="gray")
            plt.title("Predicted Mask")
            plt.axis("off")

            plt.subplot(num_rows, 3, i * 3 + 3)
            plt.imshow(label, cmap="gray")
            plt.title("Ground Truth Mask")
            plt.axis("off")

        plt.tight_layout()

        if save_path:
            save_dir = os.path.dirname(save_path)
            if save_dir:  # a bare filename has no directory to create
                os.makedirs(save_dir, exist_ok=True)
            plt.savefig(save_path)
            print(f"Imagen guardada en: {save_path}")

        plt.show()

In [None]:
def fps_on_loader(model, test_loader, device="cuda", warmup_batches=3):
    """Measure a model's inference throughput over ``test_loader``.

    Runs ``warmup_batches`` untimed batches first, then times the forward pass
    of every batch in the loader. Only ``model(x)`` is timed; data loading and
    host-to-device copies are excluded on both the CUDA and the CPU path so
    the two numbers are comparable.

    Args:
        model: module to benchmark; it is switched to ``eval()`` and moved to
            ``device`` (both persist after the call).
        test_loader: iterable of ``(inputs, targets)`` batches; targets are ignored.
        device: device string or ``torch.device``. CUDA events are used for
            timing when it names a CUDA device and CUDA is available.
        warmup_batches: number of leading batches run before timing starts.

    Returns:
        ``(fps, ms_per_img)``: images per second and milliseconds per image.

    Raises:
        ValueError: if the loader yields no images to time.
    """
    device = str(device)  # accept torch.device as well as plain strings
    model.eval()
    model.to(device)

    # Scope gradient disabling to this benchmark only. A bare
    # torch.set_grad_enabled(False) would leave autograd off process-wide.
    with torch.no_grad():
        # ---------- WARM-UP ----------
        for i, (x, _) in enumerate(test_loader):
            if i == warmup_batches:
                break
            model(x.to(device, non_blocking=True))

        # ---------- MEASUREMENT ----------
        total_imgs = 0
        elapsed_ms = 0.0
        if device.startswith("cuda") and torch.cuda.is_available():
            starter = torch.cuda.Event(enable_timing=True)
            ender = torch.cuda.Event(enable_timing=True)

            for x, _ in test_loader:
                x = x.to(device, non_blocking=True)
                torch.cuda.synchronize()  # make sure the copy is done before timing
                starter.record()

                model(x)

                ender.record()
                torch.cuda.synchronize()  # wait for the forward kernels to finish
                elapsed_ms += starter.elapsed_time(ender)
                total_imgs += x.size(0)
        else:  # CPU and other non-CUDA devices: wall-clock timing
            for x, _ in test_loader:
                x = x.to(device)
                tic = time.perf_counter()
                model(x)
                elapsed_ms += (time.perf_counter() - tic) * 1_000
                total_imgs += x.size(0)

    if total_imgs == 0:
        raise ValueError("test_loader yielded no images to time")

    fps = 1_000 * total_imgs / elapsed_ms
    ms_per_img = elapsed_ms / total_imgs
    return fps, ms_per_img

def measure_fps_repetitions(model, test_loader, device="cuda", warmup_batches=3, repetitions=5):
    """Repeat ``fps_on_loader`` and summarise the spread of the measurements.

    Args:
        model, test_loader, device, warmup_batches: forwarded unchanged to
            ``fps_on_loader`` on every repetition.
        repetitions: number of independent benchmark runs.

    Returns:
        ``(avg_fps, std_fps, avg_latency, std_latency)`` where the averages are
        ``np.mean`` and the spreads ``np.std`` (population std) across runs;
        latency is the per-image milliseconds reported by ``fps_on_loader``.
    """
    runs = [
        fps_on_loader(model, test_loader, device, warmup_batches)
        for _ in range(repetitions)
    ]
    throughputs = [throughput for throughput, _ in runs]
    latencies = [latency for _, latency in runs]

    return (
        np.mean(throughputs),
        np.std(throughputs),
        np.mean(latencies),
        np.std(latencies),
    )