In [None]:
from sklearn.metrics import roc_auc_score

class SegmentationEvaluator:
    """Evaluate a binary segmentation model on a test loader.

    The model's raw outputs are passed through a sigmoid and binarised at
    ``threshold``. A fixed set of metrics is then computed per batch and
    averaged over batches.

    Args:
        model: A ``torch.nn.Module``. Presumably it returns per-pixel logits
            with the same shape as the labels (TODO confirm).
        test_loader: Iterable of ``(images, labels)`` batches.
        threshold: Probability cut-off applied after the sigmoid.
        device: Target device. Defaults to CUDA when available, else CPU.
    """

    def __init__(self, model, test_loader, threshold=0.6, device=None):
        # Resolve the device *before* moving the model. Otherwise a ``None``
        # device leaves the model in place while the inputs are later moved
        # to ``self.device``, which can cause a device mismatch.
        self.device = (
            device
            if device is not None
            else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        )
        self.model = model.to(self.device)
        self.test_loader = test_loader
        self.threshold = threshold

        # Metric dictionary: name -> callable(preds, labels) returning a scalar.
        self.metrics_funcs = {
            "Dice Coefficient": self.dice_coefficient,
            "IoU Score": self.iou_score,
            "Precision": self.precision,
            "Sensitivity (Recall)": self.recall,
            "Specificity": self.specificity,
            "Accuracy": self.accuracy,
            "F1-Score": self.f1_score,
            "AUC": self.auc_score,
            "Hausdorff Distance": self.hausdorff_distance,
        }

    @staticmethod
    def dice_coefficient(preds, labels):
        """Return 2|P∩L| / (|P| + |L|), summed over every element of the batch."""
        preds = preds.float()
        labels = labels.float()
        intersection = (preds * labels).sum()
        union = preds.sum() + labels.sum()
        return (2.0 * intersection / (union + 1e-8)).item()

    @staticmethod
    def iou_score(preds, labels):
        """Return |P∩L| / |P∪L|, summed over every element of the batch."""
        preds = preds.float()
        labels = labels.float()
        intersection = (preds * labels).sum()
        union = preds.sum() + labels.sum() - intersection
        return (intersection / (union + 1e-8)).item()

    @staticmethod
    def precision(preds, labels):
        """Return TP / (TP + FP) over every element of the batch."""
        preds = preds.float()
        labels = labels.float()
        tp = (preds * labels).sum()
        fp = (preds * (1 - labels)).sum()
        return (tp / (tp + fp + 1e-8)).item()

    @staticmethod
    def recall(preds, labels):
        """Return TP / (TP + FN) (sensitivity) over every element of the batch."""
        preds = preds.float()
        labels = labels.float()
        tp = (preds * labels).sum()
        fn = ((1 - preds) * labels).sum()
        return (tp / (tp + fn + 1e-8)).item()

    @staticmethod
    def specificity(preds, labels):
        """Return TN / (TN + FP) over every element of the batch."""
        preds = preds.float()
        labels = labels.float()
        tn = ((1 - preds) * (1 - labels)).sum()
        fp = (preds * (1 - labels)).sum()
        return (tn / (tn + fp + 1e-8)).item()

    @staticmethod
    def accuracy(preds, labels):
        """Return the fraction of elements where ``preds`` equals ``labels``."""
        preds = preds.float()
        labels = labels.float()
        correct = (preds == labels).sum()
        total = labels.numel()
        return (correct / total).item()

    @staticmethod
    def f1_score(preds, labels):
        """Return the harmonic mean of :meth:`precision` and :meth:`recall`."""
        precision = SegmentationEvaluator.precision(preds, labels)
        recall = SegmentationEvaluator.recall(preds, labels)
        return 2 * precision * recall / (precision + recall + 1e-8)

    @staticmethod
    def auc_score(preds, labels):
        """Return the ROC AUC of ``preds`` (continuous scores) against ``labels``.

        Returns 0.5 when only one class is present in ``labels``, or when
        ``roc_auc_score`` raises ``ValueError``.
        """
        # reshape (not view) so non-contiguous tensors are accepted too.
        preds = preds.detach().reshape(-1).cpu().numpy()
        labels = labels.detach().reshape(-1).cpu().numpy()

        # AUC is undefined unless at least two classes are present.
        if len(np.unique(labels)) < 2:
            return 0.5

        try:
            return roc_auc_score(labels, preds)
        except ValueError:
            # Any other input problem also falls back to chance level.
            return 0.5

    @staticmethod
    def hausdorff_distance(preds, labels):
        """Return the mean symmetric Hausdorff distance over the batch.

        Distances are measured between the coordinates of non-zero elements
        of channel 0 of each sample. This assumes a ``(B, C, H, W)`` layout
        (TODO confirm). Units are array-index (pixel) units.

        NOTE(review): a sample where either mask is empty contributes 0.0.
        That includes the case where exactly one of the two masks is empty,
        so this can understate the error.
        """
        preds_np = preds.detach().cpu().numpy().astype(np.bool_)
        labels_np = labels.detach().cpu().numpy().astype(np.bool_)

        distances = []
        for pred, label in zip(preds_np, labels_np):
            pred_coords = np.column_stack(np.where(pred[0]))
            label_coords = np.column_stack(np.where(label[0]))

            if pred_coords.size == 0 or label_coords.size == 0:
                distances.append(0.0)
            else:
                forward_hd = directed_hausdorff(pred_coords, label_coords)[0]
                backward_hd = directed_hausdorff(label_coords, pred_coords)[0]
                distances.append(max(forward_hd, backward_hd))

        return np.mean(distances)

    def evaluate(self, visualize=False, num_images=2):
        """Run the model over ``test_loader`` and return batch-averaged metrics.

        Args:
            visualize: If True, plot the first ``num_images`` samples.
            num_images: Number of samples to plot when ``visualize`` is True.

        Returns:
            dict mapping each metric name to its mean over batches. The dict
            is also printed.

        Raises:
            ValueError: If ``test_loader`` yields no batches.
        """
        self.model.eval()
        metrics = {key: 0.0 for key in self.metrics_funcs}
        # Count iterated batches rather than relying on len(test_loader).
        num_batches = 0

        all_images, all_preds, all_labels = [], [], []
        num_collected = 0

        with torch.no_grad():
            for images, labels in self.test_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.model(images)
                # Threshold the sigmoid probabilities, not the raw logits.
                probs = torch.sigmoid(outputs)
                preds_bin = (probs > self.threshold).float()

                for key, func in self.metrics_funcs.items():
                    # AUC needs the continuous scores; every other metric
                    # is computed on the binarised mask.
                    scores = probs if key == "AUC" else preds_bin
                    metrics[key] += func(scores, labels)
                num_batches += 1

                # Only move as many samples to the CPU as will be plotted.
                if visualize and num_collected < num_images:
                    all_images.append(images.cpu())
                    all_preds.append(preds_bin.cpu())
                    all_labels.append(labels.cpu())
                    num_collected += images.shape[0]

        if num_batches == 0:
            raise ValueError("test_loader yielded no batches")

        metrics = {k: v / num_batches for k, v in metrics.items()}

        if visualize:
            self.visualize_results(all_images, all_preds, all_labels, num_images)

        print(metrics)
        return metrics

    @staticmethod
    def visualize_results(images, preds, labels, num_images=2):
        """Plot one row per sample: input image, predicted mask, ground truth.

        Args:
            images, preds, labels: Lists of batch tensors, concatenated
                along dimension 0.
            num_images: Maximum number of rows to plot. It is clamped to the
                number of available samples.
        """
        if not images:
            return

        images = torch.cat(images)[:num_images]
        preds = torch.cat(preds)[:num_images]
        labels = torch.cat(labels)[:num_images]

        # Never index past the samples that are actually available.
        num_rows = min(num_images, images.shape[0])
        if num_rows <= 0:
            return

        plt.figure(figsize=(10, 3 * num_rows))
        panels = (
            (images, "Input Image"),
            (preds, "Predicted Mask"),
            (labels, "Ground Truth Mask"),
        )

        for i in range(num_rows):
            for col, (tensor, title) in enumerate(panels, start=1):
                plt.subplot(num_rows, 3, i * 3 + col)
                # squeeze(0) drops a leading singleton channel. This
                # presumably assumes single-channel data (TODO confirm).
                plt.imshow(tensor[i].squeeze(0).numpy(), cmap="gray")
                plt.title(title)
                plt.axis("off")

        plt.tight_layout()
        plt.show()